Skip to content

Commit 957f2af

Browse files
authored
[BlockSparseArrays] More general broadcasting and slicing (#1332)
1 parent 093d339 commit 957f2af

File tree

25 files changed

+824
-126
lines changed

25 files changed

+824
-126
lines changed
Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,74 @@
11
@eval module $(gensym())
2+
using Compat: Returns
23
using Test: @test, @testset, @test_broken
34
using BlockArrays: Block, blocksize
4-
using NDTensors.BlockSparseArrays: BlockSparseArray
5-
using NDTensors.GradedAxes: gradedrange
5+
using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored
6+
using NDTensors.GradedAxes: GradedUnitRange, gradedrange
7+
using NDTensors.LabelledNumbers: label
68
using NDTensors.Sectors: U1
9+
using NDTensors.SparseArrayInterface: nstored
710
using NDTensors.TensorAlgebra: fusedims, splitdims
811
using Random: randn!
12+
function blockdiagonal!(f, a::AbstractArray)
13+
for i in 1:minimum(blocksize(a))
14+
b = Block(ntuple(Returns(i), ndims(a)))
15+
a[b] = f(a[b])
16+
end
17+
return a
18+
end
919
const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
1020
@testset "BlockSparseArraysGradedAxesExt (eltype=$elt)" for elt in elts
11-
d1 = gradedrange([U1(0) => 1, U1(1) => 1])
12-
d2 = gradedrange([U1(1) => 1, U1(0) => 1])
13-
a = BlockSparseArray{elt}(d1, d2, d1, d2)
14-
for i in 1:minimum(blocksize(a))
15-
b = Block(i, i, i, i)
16-
a[b] = randn!(a[b])
21+
@testset "map" begin
22+
d1 = gradedrange([U1(0) => 2, U1(1) => 2])
23+
d2 = gradedrange([U1(0) => 2, U1(1) => 2])
24+
a = BlockSparseArray{elt}(d1, d2, d1, d2)
25+
blockdiagonal!(randn!, a)
26+
27+
for b in (a + a, 2 * a)
28+
@test size(b) == (4, 4, 4, 4)
29+
@test blocksize(b) == (2, 2, 2, 2)
30+
@test nstored(b) == 32
31+
@test block_nstored(b) == 2
32+
# TODO: Have to investigate why this fails
33+
# on Julia v1.6, or drop support for v1.6.
34+
for i in 1:ndims(a)
35+
@test axes(b, i) isa GradedUnitRange
36+
end
37+
@test label(axes(b, 1)[Block(1)]) == U1(0)
38+
@test label(axes(b, 1)[Block(2)]) == U1(1)
39+
@test Array(a) isa Array{elt}
40+
@test Array(a) == a
41+
@test 2 * Array(a) == b
42+
end
43+
44+
b = a[2:3, 2:3, 2:3, 2:3]
45+
@test size(b) == (2, 2, 2, 2)
46+
@test blocksize(b) == (2, 2, 2, 2)
47+
@test nstored(b) == 2
48+
@test block_nstored(b) == 2
49+
for i in 1:ndims(a)
50+
@test axes(b, i) isa GradedUnitRange
51+
end
52+
@test label(axes(b, 1)[Block(1)]) == U1(0)
53+
@test label(axes(b, 1)[Block(2)]) == U1(1)
54+
@test Array(a) isa Array{elt}
55+
@test Array(a) == a
56+
end
57+
# TODO: Add tests for various slicing operations.
58+
@testset "fusedims" begin
59+
d1 = gradedrange([U1(0) => 1, U1(1) => 1])
60+
d2 = gradedrange([U1(0) => 1, U1(1) => 1])
61+
a = BlockSparseArray{elt}(d1, d2, d1, d2)
62+
blockdiagonal!(randn!, a)
63+
m = fusedims(a, (1, 2), (3, 4))
64+
@test axes(m, 1) isa GradedUnitRange
65+
@test axes(m, 2) isa GradedUnitRange
66+
@test a[1, 1, 1, 1] == m[1, 1]
67+
@test a[2, 2, 2, 2] == m[4, 4]
68+
# TODO: Current `fusedims` doesn't merge
69+
# common sectors, need to fix.
70+
@test_broken blocksize(m) == (3, 3)
71+
@test a == splitdims(m, (d1, d2), (d1, d2))
1772
end
18-
m = fusedims(a, (1, 2), (3, 4))
19-
@test a[1, 1, 1, 1] == m[2, 2]
20-
@test a[2, 2, 2, 2] == m[3, 3]
21-
# TODO: Current `fusedims` doesn't merge
22-
# common sectors, need to fix.
23-
@test_broken blocksize(m) == (3, 3)
24-
@test a == splitdims(m, (d1, d2), (d1, d2))
2573
end
2674
end

NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 149 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,48 @@
1-
using BlockArrays: AbstractBlockArray, AbstractBlockVector, Block, blockedrange
1+
using BlockArrays:
2+
BlockArrays,
3+
AbstractBlockArray,
4+
AbstractBlockVector,
5+
Block,
6+
BlockRange,
7+
BlockedUnitRange,
8+
BlockVector,
9+
block,
10+
blockaxes,
11+
blockedrange,
12+
blockindex,
13+
blocks,
14+
findblock,
15+
findblockindex
16+
using Compat: allequal
217
using Dictionaries: Dictionary, Indices
18+
using ..GradedAxes: blockedunitrange_getindices
319
using ..SparseArrayInterface: stored_indices
420

21+
# Outputs a `BlockUnitRange`.
22+
function sub_axis(a::AbstractUnitRange, indices)
23+
return error("Not implemented")
24+
end
25+
26+
# TODO: Use `GradedAxes.blockedunitrange_getindices`.
27+
# Outputs a `BlockUnitRange`.
28+
function sub_axis(a::AbstractUnitRange, indices::AbstractUnitRange)
29+
return only(axes(blockedunitrange_getindices(a, indices)))
30+
end
31+
32+
# TODO: Use `GradedAxes.blockedunitrange_getindices`.
33+
# Outputs a `BlockUnitRange`.
34+
function sub_axis(a::AbstractUnitRange, indices::AbstractVector{<:Block})
35+
return blockedrange([length(a[index]) for index in indices])
36+
end
37+
38+
# TODO: Use `GradedAxes.blockedunitrange_getindices`.
39+
# TODO: Merge blocks.
40+
function sub_axis(a::AbstractUnitRange, indices::BlockVector{<:Block})
41+
# `collect` is needed here, otherwise a `PseudoBlockVector` is
42+
# constructed.
43+
return blockedrange([length(a[index]) for index in collect(indices)])
44+
end
45+
546
# TODO: Use `Tuple` conversion once
647
# BlockArrays.jl PR is merged.
748
block_to_cartesianindex(b::Block) = CartesianIndex(b.n)
@@ -38,3 +79,110 @@ end
3879
function block_reshape(a::AbstractArray, axes::Vararg{AbstractUnitRange})
3980
return block_reshape(a, axes)
4081
end
82+
83+
function cartesianindices(axes::Tuple, b::Block)
84+
return CartesianIndices(ntuple(dim -> axes[dim][Tuple(b)[dim]], length(axes)))
85+
end
86+
87+
# Get the range within a block.
88+
function blockindexrange(axis::AbstractUnitRange, r::UnitRange)
89+
bi1 = findblockindex(axis, first(r))
90+
bi2 = findblockindex(axis, last(r))
91+
b = block(bi1)
92+
# Range must fall within a single block.
93+
@assert b == block(bi2)
94+
i1 = blockindex(bi1)
95+
i2 = blockindex(bi2)
96+
return b[i1:i2]
97+
end
98+
99+
function blockindexrange(
100+
axes::Tuple{Vararg{AbstractUnitRange,N}}, I::CartesianIndices{N}
101+
) where {N}
102+
brs = blockindexrange.(axes, I.indices)
103+
b = Block(block.(brs))
104+
rs = map(br -> only(br.indices), brs)
105+
return b[rs...]
106+
end
107+
108+
function blockindexrange(a::AbstractArray, I::CartesianIndices)
109+
return blockindexrange(axes(a), I)
110+
end
111+
112+
# Get the blocks the range spans across.
113+
function blockrange(axis::AbstractUnitRange, r::UnitRange)
114+
return findblock(axis, first(r)):findblock(axis, last(r))
115+
end
116+
117+
function blockrange(axis::AbstractUnitRange, r::Int)
118+
error("Slicing with integer values isn't supported.")
119+
return findblock(axis, r)
120+
end
121+
122+
function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Block{1}})
123+
for b in r
124+
@assert b blockaxes(axis, 1)
125+
end
126+
return r
127+
end
128+
129+
using BlockArrays: BlockSlice
130+
function blockrange(axis::AbstractUnitRange, r::BlockSlice)
131+
return blockrange(axis, r.block)
132+
end
133+
134+
function blockrange(axis::AbstractUnitRange, r)
135+
return error("Slicing not implemented for range of type `$(typeof(r))`.")
136+
end
137+
138+
function cartesianindices(a::AbstractArray, b::Block)
139+
return cartesianindices(axes(a), b)
140+
end
141+
142+
# Output which blocks of `axis` are contained within the unit range `range`.
143+
# The start and end points must match.
144+
function findblocks(axis::AbstractUnitRange, range::AbstractUnitRange)
145+
# TODO: Add a test that the start and end points of the ranges match.
146+
return findblock(axis, first(range)):findblock(axis, last(range))
147+
end
148+
149+
function block_stored_indices(a::AbstractArray)
150+
return Block.(Tuple.(stored_indices(blocks(a))))
151+
end
152+
153+
_block(indices) = block(indices)
154+
_block(indices::CartesianIndices) = Block(ntuple(Returns(1), ndims(indices)))
155+
156+
function combine_axes(as::Vararg{Tuple})
157+
@assert allequal(length.(as))
158+
ndims = length(first(as))
159+
return ntuple(ndims) do dim
160+
dim_axes = map(a -> a[dim], as)
161+
return reduce(BlockArrays.combine_blockaxes, dim_axes)
162+
end
163+
end
164+
165+
# Returns `BlockRange`
166+
# Convert the block of the axes to blocks of the subaxes.
167+
function subblocks(axes::Tuple, subaxes::Tuple, block::Block)
168+
@assert length(axes) == length(subaxes)
169+
return BlockRange(
170+
ntuple(length(axes)) do dim
171+
findblocks(subaxes[dim], axes[dim][Tuple(block)[dim]])
172+
end,
173+
)
174+
end
175+
176+
# Returns `Vector{<:Block}`
177+
function subblocks(axes::Tuple, subaxes::Tuple, blocks)
178+
return mapreduce(vcat, blocks; init=eltype(blocks)[]) do block
179+
return vec(subblocks(axes, subaxes, block))
180+
end
181+
end
182+
183+
# Returns `Vector{<:CartesianIndices}`
184+
function blocked_cartesianindices(axes::Tuple, subaxes::Tuple, blocks)
185+
return map(subblocks(axes, subaxes, blocks)) do block
186+
return cartesianindices(subaxes, block)
187+
end
188+
end
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
1+
using BlockArrays: AbstractBlockArray, BlocksView
12
using ..SparseArrayInterface: SparseArrayInterface, nstored
23

34
function SparseArrayInterface.nstored(a::AbstractBlockArray)
45
return sum(b -> nstored(b), blocks(a); init=zero(Int))
56
end
7+
8+
# TODO: Handle `BlocksView` wrapping a sparse array?
9+
function SparseArrayInterface.storage_indices(a::BlocksView)
10+
return CartesianIndices(a)
11+
end

NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ include("blocksparsearrayinterface/blocksparsearrayinterface.jl")
33
include("blocksparsearrayinterface/linearalgebra.jl")
44
include("blocksparsearrayinterface/blockzero.jl")
55
include("blocksparsearrayinterface/broadcast.jl")
6+
include("blocksparsearrayinterface/arraylayouts.jl")
67
include("abstractblocksparsearray/abstractblocksparsearray.jl")
78
include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl")
89
include("abstractblocksparsearray/abstractblocksparsematrix.jl")
910
include("abstractblocksparsearray/abstractblocksparsevector.jl")
11+
include("abstractblocksparsearray/view.jl")
1012
include("abstractblocksparsearray/arraylayouts.jl")
1113
include("abstractblocksparsearray/sparsearrayinterface.jl")
1214
include("abstractblocksparsearray/linearalgebra.jl")

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsearray.jl

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,26 @@ Base.axes(::AbstractBlockSparseArray) = error("Not implemented")
1414

1515
blockstype(::Type{<:AbstractBlockSparseArray}) = error("Not implemented")
1616

17-
# Specialized in order to fix ambiguity error with `BlockArrays`.
17+
## # Specialized in order to fix ambiguity error with `BlockArrays`.
1818
function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Vararg{Int,N}) where {N}
1919
return blocksparse_getindex(a, I...)
2020
end
2121

22-
# Fix ambiguity error with `BlockArrays`.
23-
function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Block{N}) where {N}
24-
return ArrayLayouts.layout_getindex(a, I)
25-
end
26-
27-
# Fix ambiguity error with `BlockArrays`.
28-
function Base.getindex(a::AbstractBlockSparseArray{<:Any,1}, I::Block{1})
29-
return ArrayLayouts.layout_getindex(a, I)
30-
end
31-
32-
# Fix ambiguity error with `BlockArrays`.
33-
function Base.getindex(a::AbstractBlockSparseArray, I::Vararg{AbstractVector})
34-
return blocksparse_getindex(a, I...)
35-
end
22+
## # Fix ambiguity error with `BlockArrays`.
23+
## function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Block{N}) where {N}
24+
## return ArrayLayouts.layout_getindex(a, I)
25+
## end
26+
##
27+
## # Fix ambiguity error with `BlockArrays`.
28+
## function Base.getindex(a::AbstractBlockSparseArray{<:Any,1}, I::Block{1})
29+
## return ArrayLayouts.layout_getindex(a, I)
30+
## end
31+
##
32+
## # Fix ambiguity error with `BlockArrays`.
33+
## function Base.getindex(a::AbstractBlockSparseArray, I::Vararg{AbstractVector})
34+
## ## return blocksparse_getindex(a, I...)
35+
## return ArrayLayouts.layout_getindex(a, I...)
36+
## end
3637

3738
# Specialized in order to fix ambiguity error with `BlockArrays`.
3839
function Base.setindex!(
Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
using ArrayLayouts: ArrayLayouts, MemoryLayout, MatMulMatAdd, MulAdd
1+
using ArrayLayouts: ArrayLayouts, MemoryLayout, MulAdd
22
using BlockArrays: BlockLayout
33
using ..SparseArrayInterface: SparseLayout
44
using LinearAlgebra: mul!
55

6-
# TODO: Generalize to `BlockSparseArrayLike`.
7-
function ArrayLayouts.MemoryLayout(arraytype::Type{<:AbstractBlockSparseArray})
6+
function ArrayLayouts.MemoryLayout(arraytype::Type{<:BlockSparseArrayLike})
87
outer_layout = typeof(MemoryLayout(blockstype(arraytype)))
98
inner_layout = typeof(MemoryLayout(blocktype(arraytype)))
109
return BlockLayout{outer_layout,inner_layout}()
@@ -16,14 +15,9 @@ function Base.similar(
1615
return similar(BlockSparseArray{elt}, axes)
1716
end
1817

19-
function ArrayLayouts.materialize!(
20-
m::MatMulMatAdd{
21-
<:BlockLayout{<:SparseLayout},
22-
<:BlockLayout{<:SparseLayout},
23-
<:BlockLayout{<:SparseLayout},
24-
},
25-
)
26-
α, a1, a2, β, a_dest = m.α, m.A, m.B, m.β, m.C
27-
mul!(a_dest, a1, a2, α, β)
18+
# Materialize a SubArray view.
19+
function ArrayLayouts.sub_materialize(layout::BlockLayout{<:SparseLayout}, a, axes)
20+
a_dest = BlockSparseArray{eltype(a)}(axes)
21+
a_dest .= a
2822
return a_dest
2923
end
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,48 @@
1+
using BlockArrays: BlockedUnitRange, BlockSlice
12
using Base.Broadcast: Broadcast
23

34
function Broadcast.BroadcastStyle(arraytype::Type{<:BlockSparseArrayLike})
45
return BlockSparseArrayStyle{ndims(arraytype)}()
56
end
7+
8+
# Fix ambiguity error with `BlockArrays`.
9+
function Broadcast.BroadcastStyle(
10+
arraytype::Type{
11+
<:SubArray{
12+
<:Any,
13+
<:Any,
14+
<:AbstractBlockSparseArray,
15+
<:Tuple{BlockSlice{<:Any,<:BlockedUnitRange},Vararg{Any}},
16+
},
17+
},
18+
)
19+
return BlockSparseArrayStyle{ndims(arraytype)}()
20+
end
21+
function Broadcast.BroadcastStyle(
22+
arraytype::Type{
23+
<:SubArray{
24+
<:Any,
25+
<:Any,
26+
<:AbstractBlockSparseArray,
27+
<:Tuple{
28+
BlockSlice{<:Any,<:BlockedUnitRange},
29+
BlockSlice{<:Any,<:BlockedUnitRange},
30+
Vararg{Any},
31+
},
32+
},
33+
},
34+
)
35+
return BlockSparseArrayStyle{ndims(arraytype)}()
36+
end
37+
function Broadcast.BroadcastStyle(
38+
arraytype::Type{
39+
<:SubArray{
40+
<:Any,
41+
<:Any,
42+
<:AbstractBlockSparseArray,
43+
<:Tuple{Any,BlockSlice{<:Any,<:BlockedUnitRange},Vararg{Any}},
44+
},
45+
},
46+
)
47+
return BlockSparseArrayStyle{ndims(arraytype)}()
48+
end

0 commit comments

Comments
 (0)