1
1
module BlockSparseArraysTensorAlgebraExt
2
2
using BlockArrays: AbstractBlockedUnitRange
3
- using GradedUnitRanges: tensor_product
4
- using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion
5
3
6
- function TensorAlgebra.:⊗ (a1:: AbstractBlockedUnitRange , a2:: AbstractBlockedUnitRange )
7
- return tensor_product (a1, a2)
8
- end
4
+ using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion
5
+ using TensorProducts: OneToOne
9
6
10
- using BlockArrays: AbstractBlockedUnitRange
11
7
using BlockSparseArrays: AbstractBlockSparseArray, blockreshape
12
- using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion
13
8
14
9
TensorAlgebra. FusionStyle (:: AbstractBlockedUnitRange ) = BlockReshapeFusion ()
15
10
@@ -46,13 +41,12 @@ using DerivableInterfaces: @interface
46
41
using GradedUnitRanges:
47
42
GradedUnitRanges,
48
43
AbstractGradedUnitRange,
49
- OneToOne,
50
44
blockmergesortperm,
51
45
blocksortperm,
52
46
dual,
53
47
invblockperm,
54
48
nondual,
55
- tensor_product
49
+ unmerged_tensor_product
56
50
using LinearAlgebra: Adjoint, Transpose
57
51
using TensorAlgebra:
58
52
TensorAlgebra, FusionStyle, BlockReshapeFusion, SectorFusion, fusedims, splitdims
@@ -77,10 +71,17 @@ function block_mergesort(a::AbstractArray)
77
71
end
78
72
79
73
function TensorAlgebra. fusedims (
80
- :: SectorFusion , a:: AbstractArray , axes :: AbstractUnitRange...
74
+ :: SectorFusion , a:: AbstractArray , merged_axes :: AbstractUnitRange...
81
75
)
82
76
# First perform a fusion using a block reshape.
83
- a_reshaped = fusedims (BlockReshapeFusion (), a, axes... )
77
+ # TODO avoid groupreducewhile. Require refactor of fusedims.
78
+ unmerged_axes = groupreducewhile (
79
+ unmerged_tensor_product, axes (a), length (merged_axes); init= OneToOne ()
80
+ ) do i, axis
81
+ return length (axis) ≤ length (merged_axes[i])
82
+ end
83
+
84
+ a_reshaped = fusedims (BlockReshapeFusion (), a, unmerged_axes... )
84
85
# Sort the blocks by sector and merge the equivalent sectors.
85
86
return block_mergesort (a_reshaped)
86
87
end
@@ -90,10 +91,11 @@ function TensorAlgebra.splitdims(
90
91
)
91
92
# First, fuse axes to get `blockmergesortperm`.
92
93
# Then unpermute the blocks.
93
- axes_prod =
94
- groupreducewhile (tensor_product, split_axes, ndims (a); init= OneToOne ()) do i, axis
95
- return length (axis) ≤ length (axes (a, i))
96
- end
94
+ axes_prod = groupreducewhile (
95
+ unmerged_tensor_product, split_axes, ndims (a); init= OneToOne ()
96
+ ) do i, axis
97
+ return length (axis) ≤ length (axes (a, i))
98
+ end
97
99
blockperms = blocksortperm .(axes_prod)
98
100
sorted_axes = map ((r, I) -> only (axes (r[I])), axes_prod, blockperms)
99
101
@@ -106,34 +108,11 @@ function TensorAlgebra.splitdims(
106
108
return splitdims (BlockReshapeFusion (), a_blockpermed, split_axes... )
107
109
end
108
110
109
- # This is a temporary fix for `eachindex` being broken for BlockSparseArrays
110
- # with mixed dual and non-dual axes. This shouldn't be needed once
111
- # GradedUnitRanges is rewritten using BlockArrays v1.
112
- # TODO : Delete this once GradedUnitRanges is rewritten.
113
- function Base. eachindex (a:: AbstractBlockSparseArray )
114
- return CartesianIndices (nondual .(axes (a)))
115
- end
116
-
117
111
# TODO : Handle this through some kind of trait dispatch, maybe
118
112
# a `SymmetryStyle`-like trait to check if the block sparse
119
113
# matrix has graded axes.
120
114
function Base. axes (a:: Adjoint{<:Any,<:AbstractBlockSparseMatrix} )
121
115
return dual .(reverse (axes (a' )))
122
116
end
123
117
124
- # This definition is only needed since calls like
125
- # `a[[Block(1), Block(2)]]` where `a isa AbstractGradedUnitRange`
126
- # returns a `BlockSparseVector` instead of a `BlockVector`
127
- # due to limitations in the `BlockArray` type not allowing
128
- # axes with non-Int element types.
129
- # TODO : Remove this once that issue is fixed,
130
- # see https://github.com/JuliaArrays/BlockArrays.jl/pull/405.
131
- using BlockArrays: BlockRange
132
- using LabelledNumbers: label
133
- function GradedUnitRanges. blocklabels (a:: BlockSparseVector )
134
- return map (BlockRange (a)) do block
135
- return label (blocks (a)[Int (block)])
136
- end
137
- end
138
-
139
118
end
0 commit comments