Commit 0ea81a7

Adapt to TensorProducts.jl (#89)

1 parent cac9639 commit 0ea81a7

4 files changed: +39, -48 lines changed

Project.toml
Lines changed: 6 additions & 6 deletions

@@ -1,7 +1,7 @@
 name = "BlockSparseArrays"
 uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
 authors = ["ITensor developers <[email protected]> and contributors"]
-version = "0.3.8"
+version = "0.3.9"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -21,12 +21,12 @@ SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
 TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"

 [weakdeps]
-LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993"
 TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
+TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"

 [extensions]
 BlockSparseArraysGradedUnitRangesExt = "GradedUnitRanges"
-BlockSparseArraysTensorAlgebraExt = ["LabelledNumbers", "TensorAlgebra"]
+BlockSparseArraysTensorAlgebraExt = ["TensorProducts", "TensorAlgebra"]

 [compat]
 Adapt = "4.1.1"
@@ -38,14 +38,14 @@ DiagonalArrays = "0.3"
 Dictionaries = "0.4.3"
 FillArrays = "1.13.0"
 GPUArraysCore = "0.1.0, 0.2"
-GradedUnitRanges = "0.1.0"
-LabelledNumbers = "0.1.0"
+GradedUnitRanges = "0.2.2"
 LinearAlgebra = "1.10"
 MacroTools = "0.5.13"
 MapBroadcast = "0.1.5"
 SparseArraysBase = "0.5"
 SplitApplyCombine = "1.2.3"
-TensorAlgebra = "0.1.0, 0.2"
+TensorAlgebra = "0.2.4"
+TensorProducts = "0.1.2"
 Test = "1.10"
 TypeParameterAccessors = "0.2.0, 0.3"
 julia = "1.10"
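For orientation (not part of the commit): the `[weakdeps]` change swaps LabelledNumbers for TensorProducts as a trigger of the BlockSparseArraysTensorAlgebraExt extension. A minimal sketch, assuming a Julia >= 1.9 session with all three packages installed, of checking that the extension loads:

```julia
# Minimal sketch (assumption, not from the commit): the extension module
# loads once both [weakdeps] triggers are loaded next to BlockSparseArrays.
using BlockSparseArrays
using TensorAlgebra, TensorProducts  # TensorProducts replaces LabelledNumbers

# `Base.get_extension` returns the extension module, or `nothing`
# if the trigger packages have not been loaded yet.
ext = Base.get_extension(BlockSparseArrays, :BlockSparseArraysTensorAlgebraExt)
@assert !isnothing(ext)
```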

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl
Lines changed: 17 additions & 38 deletions

@@ -1,15 +1,10 @@
 module BlockSparseArraysTensorAlgebraExt
 using BlockArrays: AbstractBlockedUnitRange
-using GradedUnitRanges: tensor_product
-using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion

-function TensorAlgebra.:⊗(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange)
-  return tensor_product(a1, a2)
-end
+using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion
+using TensorProducts: OneToOne

-using BlockArrays: AbstractBlockedUnitRange
 using BlockSparseArrays: AbstractBlockSparseArray, blockreshape
-using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion

 TensorAlgebra.FusionStyle(::AbstractBlockedUnitRange) = BlockReshapeFusion()

@@ -46,13 +41,12 @@ using DerivableInterfaces: @interface
 using GradedUnitRanges:
   GradedUnitRanges,
   AbstractGradedUnitRange,
-  OneToOne,
   blockmergesortperm,
   blocksortperm,
   dual,
   invblockperm,
   nondual,
-  tensor_product
+  unmerged_tensor_product
 using LinearAlgebra: Adjoint, Transpose
 using TensorAlgebra:
   TensorAlgebra, FusionStyle, BlockReshapeFusion, SectorFusion, fusedims, splitdims
@@ -77,10 +71,17 @@ function block_mergesort(a::AbstractArray)
 end

 function TensorAlgebra.fusedims(
-  ::SectorFusion, a::AbstractArray, axes::AbstractUnitRange...
+  ::SectorFusion, a::AbstractArray, merged_axes::AbstractUnitRange...
 )
   # First perform a fusion using a block reshape.
-  a_reshaped = fusedims(BlockReshapeFusion(), a, axes...)
+  # TODO avoid groupreducewhile. Require refactor of fusedims.
+  unmerged_axes = groupreducewhile(
+    unmerged_tensor_product, axes(a), length(merged_axes); init=OneToOne()
+  ) do i, axis
+    return length(axis) ≤ length(merged_axes[i])
+  end
+
+  a_reshaped = fusedims(BlockReshapeFusion(), a, unmerged_axes...)
   # Sort the blocks by sector and merge the equivalent sectors.
   return block_mergesort(a_reshaped)
 end
@@ -90,10 +91,11 @@ function TensorAlgebra.splitdims(
 )
   # First, fuse axes to get `blockmergesortperm`.
   # Then unpermute the blocks.
-  axes_prod =
-    groupreducewhile(tensor_product, split_axes, ndims(a); init=OneToOne()) do i, axis
-      return length(axis) ≤ length(axes(a, i))
-    end
+  axes_prod = groupreducewhile(
+    unmerged_tensor_product, split_axes, ndims(a); init=OneToOne()
+  ) do i, axis
+    return length(axis) ≤ length(axes(a, i))
+  end
   blockperms = blocksortperm.(axes_prod)
   sorted_axes = map((r, I) -> only(axes(r[I])), axes_prod, blockperms)

@@ -106,34 +108,11 @@
   return splitdims(BlockReshapeFusion(), a_blockpermed, split_axes...)
 end

-# This is a temporary fix for `eachindex` being broken for BlockSparseArrays
-# with mixed dual and non-dual axes. This shouldn't be needed once
-# GradedUnitRanges is rewritten using BlockArrays v1.
-# TODO: Delete this once GradedUnitRanges is rewritten.
-function Base.eachindex(a::AbstractBlockSparseArray)
-  return CartesianIndices(nondual.(axes(a)))
-end
-
 # TODO: Handle this through some kind of trait dispatch, maybe
 # a `SymmetryStyle`-like trait to check if the block sparse
 # matrix has graded axes.
 function Base.axes(a::Adjoint{<:Any,<:AbstractBlockSparseMatrix})
   return dual.(reverse(axes(a')))
 end

-# This definition is only needed since calls like
-# `a[[Block(1), Block(2)]]` where `a isa AbstractGradedUnitRange`
-# returns a `BlockSparseVector` instead of a `BlockVector`
-# due to limitations in the `BlockArray` type not allowing
-# axes with non-Int element types.
-# TODO: Remove this once that issue is fixed,
-# see https://github.com/JuliaArrays/BlockArrays.jl/pull/405.
-using BlockArrays: BlockRange
-using LabelledNumbers: label
-function GradedUnitRanges.blocklabels(a::BlockSparseVector)
-  return map(BlockRange(a)) do block
-    return label(blocks(a)[Int(block)])
-  end
-end
-
 end
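To make the rewritten fusion path concrete, here is a hedged sketch of fusing and splitting a graded block sparse array, modeled on this repository's tests; the tuple-grouping form of `fusedims`/`splitdims` and the specific block assignments are assumptions, not code from the commit:

```julia
# Hedged sketch (assumed API, modeled on the tests in this repository).
using BlockArrays: Block
using BlockSparseArrays: BlockSparseArray
using GradedUnitRanges: dual, gradedrange
using SymmetrySectors: U1
using TensorAlgebra: fusedims, splitdims

r = gradedrange([U1(0) => 2, U1(1) => 3])
a = BlockSparseArray{Float64}(undef, (r, r, dual(r), dual(r)))
a[Block(1, 1, 1, 1)] = randn(2, 2, 2, 2)  # charge-neutral block
a[Block(2, 2, 2, 2)] = randn(3, 3, 3, 3)  # charge-neutral block

# Fuse dimensions (1, 2) and (3, 4): the SectorFusion path first
# block-reshapes along axes built by `unmerged_tensor_product`, then
# sorts and merges blocks with equivalent sectors.
m = fusedims(a, (1, 2), (3, 4))

# Split the merged axes back into the original ones.
b = splitdims(m, (r, r), (dual(r), dual(r)))
```

The `groupreducewhile` call introduced in `fusedims` recovers this grouping from `axes(a)`: it accumulates consecutive axes with `unmerged_tensor_product` while the accumulated length stays within `length(merged_axes[i])`.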

test/Project.toml
Lines changed: 3 additions & 3 deletions

@@ -29,7 +29,7 @@ BlockArrays = "1"
 BlockSparseArrays = "0.3"
 DiagonalArrays = "0.3"
 GPUArraysCore = "0.2"
-GradedUnitRanges = "0.1"
+GradedUnitRanges = "0.2.2"
 JLArrays = "0.2"
 LabelledNumbers = "0.1"
 LinearAlgebra = "1"
@@ -38,8 +38,8 @@ Random = "1"
 SafeTestsets = "0.1"
 SparseArraysBase = "0.5"
 Suppressor = "0.2"
-SymmetrySectors = "0.1"
-TensorAlgebra = "0.2"
+SymmetrySectors = "0.1.7"
+TensorAlgebra = "0.2.4"
 Test = "1"
 TestExtras = "0.3"
 TypeParameterAccessors = "0.3"

test/test_gradedunitrangesext.jl
Lines changed: 13 additions & 1 deletion

@@ -1,4 +1,3 @@
-@eval module $(gensym())
 using Test: @test, @testset
 using BlockArrays:
   AbstractBlockArray, Block, BlockedOneTo, blockedrange, blocklengths, blocksize
@@ -10,6 +9,7 @@ using GradedUnitRanges:
   GradedUnitRange,
   GradedUnitRangeDual,
   blocklabels,
+  dag,
   dual,
   gradedrange,
   isdual
@@ -19,6 +19,7 @@ using SymmetrySectors: U1
 using TensorAlgebra: fusedims, splitdims
 using LinearAlgebra: adjoint
 using Random: randn!
+
 function randn_blockdiagonal(elt::Type, axes::Tuple)
   a = BlockSparseArray{elt}(undef, axes)
   blockdiaglength = minimum(blocksize(a))
@@ -390,4 +391,15 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
   @test all(GradedUnitRanges.space_isequal.(axes(b), (r, dual(r))))
 end
 end
+
+@testset "dag" begin
+  elt = ComplexF64
+  r = gradedrange([U1(0) => 2, U1(1) => 3])
+  a = BlockSparseArray{elt}(undef, r, dual(r))
+  a[Block(1, 1)] = randn(elt, 2, 2)
+  a[Block(2, 2)] = randn(elt, 3, 3)
+  @test isdual.(axes(a)) == (false, true)
+  ad = dag(a)
+  @test Array(ad) == conj(Array(a))
+  @test isdual.(axes(ad)) == (true, false)
 end
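As the new test indicates, `dag` composes entrywise conjugation with taking the dual of each axis: `Array(dag(a)) == conj(Array(a))` while `isdual` flips on every axis, unlike `adjoint`, which additionally reverses the dimensions.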
