Skip to content

Commit c3eb3e7

Browse files
authored
[BlockSparseArrays] Improve the design of block views (#1481)
* [NDTensors] Bump to v0.3.19
1 parent 83546e7 commit c3eb3e7

File tree

13 files changed

+222
-30
lines changed

13 files changed

+222
-30
lines changed

NDTensors/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NDTensors"
22
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
33
authors = ["Matthew Fishman <[email protected]>"]
4-
version = "0.3.18"
4+
version = "0.3.19"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
8080
@testset "dual axes" begin
8181
r = gradedrange([U1(0) => 2, U1(1) => 2])
8282
a = BlockSparseArray{elt}(dual(r), r)
83-
a[Block(1, 1)] = randn(size(a[Block(1, 1)]))
84-
a[Block(2, 2)] = randn(size(a[Block(2, 2)]))
83+
a[Block(1, 1)] = randn(elt, size(a[Block(1, 1)]))
84+
a[Block(2, 2)] = randn(elt, size(a[Block(2, 2)]))
8585
a_dense = Array(a)
8686
@test eachindex(a) == CartesianIndices(size(a))
8787
for I in eachindex(a)

NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using BlockArrays:
66
BlockRange,
77
BlockedUnitRange,
88
BlockVector,
9+
BlockSlice,
910
block,
1011
blockaxes,
1112
blockedrange,
@@ -29,6 +30,36 @@ function sub_axis(a::AbstractUnitRange, indices::AbstractUnitRange)
2930
return only(axes(blockedunitrange_getindices(a, indices)))
3031
end
3132

33+
# TODO: Use `GradedAxes.blockedunitrange_getindices`.
34+
# Outputs a `BlockUnitRange`.
35+
function sub_axis(a::AbstractUnitRange, indices::BlockSlice{<:BlockRange{1}})
36+
return sub_axis(a, indices.block)
37+
end
38+
39+
# TODO: Use `GradedAxes.blockedunitrange_getindices`.
40+
# Outputs a `BlockUnitRange`.
41+
function sub_axis(a::AbstractUnitRange, indices::BlockSlice{<:Block{1}})
42+
return sub_axis(a, Block(indices))
43+
end
44+
45+
# TODO: Use `GradedAxes.blockedunitrange_getindices`.
46+
# Outputs a `BlockUnitRange`.
47+
function sub_axis(a::AbstractUnitRange, indices::BlockSlice{<:BlockIndexRange{1}})
48+
return sub_axis(a, indices.block)
49+
end
50+
51+
# TODO: Use `GradedAxes.blockedunitrange_getindices`.
52+
# Outputs a `BlockUnitRange`.
53+
function sub_axis(a::AbstractUnitRange, indices::Block)
54+
return only(axes(blockedunitrange_getindices(a, indices)))
55+
end
56+
57+
# TODO: Use `GradedAxes.blockedunitrange_getindices`.
58+
# Outputs a `BlockUnitRange`.
59+
function sub_axis(a::AbstractUnitRange, indices::BlockIndexRange)
60+
return only(axes(blockedunitrange_getindices(a, indices)))
61+
end
62+
3263
# TODO: Use `GradedAxes.blockedunitrange_getindices`.
3364
# Outputs a `BlockUnitRange`.
3465
function sub_axis(a::AbstractUnitRange, indices::AbstractVector{<:Block})
@@ -131,6 +162,14 @@ function blockrange(axis::AbstractUnitRange, r::BlockSlice)
131162
return blockrange(axis, r.block)
132163
end
133164

165+
function blockrange(axis::AbstractUnitRange, r::Block{1})
166+
return r:r
167+
end
168+
169+
function blockrange(axis::AbstractUnitRange, r::BlockIndexRange)
170+
return Block(r):Block(r)
171+
end
172+
134173
function blockrange(axis::AbstractUnitRange, r)
135174
return error("Slicing not implemented for range of type `$(typeof(r))`.")
136175
end

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,18 @@ end
1616

1717
# Materialize a SubArray view.
1818
function ArrayLayouts.sub_materialize(layout::BlockLayout{<:SparseLayout}, a, axes)
19+
# TODO: Make more generic for GPU.
1920
a_dest = BlockSparseArray{eltype(a)}(axes)
2021
a_dest .= a
2122
return a_dest
2223
end
24+
25+
# Materialize a SubArray view.
26+
function ArrayLayouts.sub_materialize(
27+
layout::BlockLayout{<:SparseLayout}, a, axes::Tuple{Vararg{Base.OneTo}}
28+
)
29+
# TODO: Make more generic for GPU.
30+
a_dest = Array{eltype(a)}(undef, length.(axes))
31+
a_dest .= a
32+
return a_dest
33+
end

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,14 @@ function SparseArrayInterface.sparse_map!(
3232
for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...)
3333
BI_dest = blockindexrange(a_dest, I)
3434
BI_srcs = map(a_src -> blockindexrange(a_src, I), a_srcs)
35-
block_dest = @view a_dest[_block(BI_dest)]
36-
block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs))
35+
# TODO: Investigate why this doesn't work:
36+
# block_dest = @view a_dest[_block(BI_dest)]
37+
block_dest = blocks(a_dest)[Int.(Tuple(_block(BI_dest)))...]
38+
# TODO: Investigate why this doesn't work:
39+
# block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs))
40+
block_srcs = ntuple(length(a_srcs)) do i
41+
return blocks(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...]
42+
end
3743
subblock_dest = @view block_dest[BI_dest.indices...]
3844
subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs))
3945
# TODO: Use `map!!` to handle immutable blocks.

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/view.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
using BlockArrays: BlockIndexRange, BlockRange, BlockSlice, block
22

3-
function blocksparse_view(a::AbstractArray, index::Block)
4-
return blocks(a)[Int.(Tuple(index))...]
5-
end
6-
73
# TODO: Define `AnyBlockSparseVector`.
84
function Base.view(a::BlockSparseArrayLike{<:Any,N}, index::Block{N}) where {N}
95
return blocksparse_view(a, index)

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,20 @@ function Base.getindex(a::BlockSparseArrayLike{<:Any,2}, I::Vararg{AbstractUnitR
5555
return ArrayLayouts.layout_getindex(a, I...)
5656
end
5757

58-
function Base.isassigned(a::BlockSparseArrayLike, index::Vararg{Block})
58+
function Base.getindex(a::BlockSparseArrayLike{<:Any,N}, block::Block{N}) where {N}
59+
return blocksparse_getindex(a, block)
60+
end
61+
function Base.getindex(
62+
a::BlockSparseArrayLike{<:Any,N}, block::Vararg{Block{1},N}
63+
) where {N}
64+
return blocksparse_getindex(a, block...)
65+
end
66+
67+
# TODO: Define `issasigned(a, ::Block{N})`.
68+
function Base.isassigned(
69+
a::BlockSparseArrayLike{<:Any,N}, index::Vararg{Block{1},N}
70+
) where {N}
71+
# TODO: Define `blocksparse_isassigned`.
5972
return isassigned(blocks(a), Int.(index)...)
6073
end
6174

@@ -64,6 +77,12 @@ function Base.setindex!(a::BlockSparseArrayLike{<:Any,N}, value, I::BlockIndex{N
6477
return a
6578
end
6679

80+
function Base.setindex!(
81+
a::BlockSparseArrayLike{<:Any,N}, value, I::Vararg{Block{1},N}
82+
) where {N}
83+
a[Block(Int.(I))] = value
84+
return a
85+
end
6786
function Base.setindex!(a::BlockSparseArrayLike{<:Any,N}, value, I::Block{N}) where {N}
6887
blocksparse_setindex!(a, value, I)
6988
return a

NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@ function blocksparse_getindex(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where
1919
return a[findblockindex.(axes(a), I)...]
2020
end
2121

22+
function blocksparse_getindex(a::AbstractArray{<:Any,N}, I::Block{N}) where {N}
23+
return blocksparse_getindex(a, Tuple(I)...)
24+
end
25+
function blocksparse_getindex(a::AbstractArray{<:Any,N}, I::Vararg{Block{1},N}) where {N}
26+
# TODO: Avoid copy if the block isn't stored.
27+
return copy(blocks(a)[Int.(I)...])
28+
end
29+
2230
# TODO: Implement as `copy(@view a[I...])`, which is then implemented
2331
# through `ArrayLayouts.sub_materialize`.
2432
using ..SparseArrayInterface: set_getindex_zero_function
@@ -59,21 +67,41 @@ function blocksparse_setindex!(a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N
5967
end
6068

6169
function blocksparse_setindex!(a::AbstractArray{<:Any,N}, value, I::BlockIndex{N}) where {N}
62-
a_b = view(a, block(I))
70+
i = Int.(Tuple(block(I)))
71+
a_b = blocks(a)[i...]
6372
a_b[I.α...] = value
64-
# Set the block, required if it is structurally zero
65-
a[block(I)] = a_b
73+
# Set the block, required if it is structurally zero.
74+
blocks(a)[i...] = a_b
6675
return a
6776
end
6877

6978
function blocksparse_setindex!(a::AbstractArray{<:Any,N}, value, I::Block{N}) where {N}
70-
# TODO: Create a conversion function, say `CartesianIndex(Int.(Tuple(I)))`.
71-
i = I.n
79+
blocksparse_setindex!(a, value, Tuple(I)...)
80+
return a
81+
end
82+
function blocksparse_setindex!(
83+
a::AbstractArray{<:Any,N}, value, I::Vararg{Block{1},N}
84+
) where {N}
85+
i = Int.(I)
7286
@boundscheck blockcheckbounds(a, i...)
87+
# TODO: Use `blocksizes(a)[i...]` when we upgrade to
88+
# BlockArrays.jl v1.
89+
if size(value) size(view(a, I...))
90+
return throw(
91+
DimensionMismatch("Trying to set a block with an array of the wrong size.")
92+
)
93+
end
7394
blocks(a)[i...] = value
7495
return a
7596
end
7697

98+
function blocksparse_view(a::AbstractArray{<:Any,N}, I::Block{N}) where {N}
99+
return blocksparse_view(a, Tuple(I)...)
100+
end
101+
function blocksparse_view(a::AbstractArray{<:Any,N}, I::Vararg{Block{1},N}) where {N}
102+
return SubArray(a, to_indices(a, I))
103+
end
104+
77105
function blocksparse_viewblock(a::AbstractArray{<:Any,N}, I::Block{N}) where {N}
78106
# TODO: Create a conversion function, say `CartesianIndex(Int.(Tuple(I)))`.
79107
i = I.n

NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using LinearAlgebra: mul!
44
using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored, block_reshape
55
using NDTensors.SparseArrayInterface: nstored
66
using NDTensors.TensorAlgebra: contract
7-
using Test: @test, @testset, @test_broken
7+
using Test: @test, @test_broken, @test_throws, @testset
88
include("TestBlockSparseArraysUtils.jl")
99
@testset "BlockSparseArrays (eltype=$elt)" for elt in
1010
(Float32, Float64, ComplexF32, ComplexF64)
@@ -20,6 +20,7 @@ include("TestBlockSparseArraysUtils.jl")
2020
@test block_nstored(a) == 0
2121
@test iszero(a)
2222
@test all(I -> iszero(a[I]), eachindex(a))
23+
@test_throws DimensionMismatch a[Block(1, 1)] = randn(elt, 2, 3)
2324

2425
a = BlockSparseArray{elt}([2, 3], [2, 3])
2526
a[3, 3] = 33
@@ -225,36 +226,59 @@ include("TestBlockSparseArraysUtils.jl")
225226
@test block_nstored(c) == 2
226227
@test Array(c) == 2 * transpose(Array(a))
227228

228-
## Broken, need to fix.
229-
230229
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
231230
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
232231
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
233-
@test_broken a[Block(1), Block(1):Block(2)]
232+
b = a[Block(1), Block(1):Block(2)]
233+
@test size(b) == (2, 7)
234+
@test blocksize(b) == (1, 2)
235+
@test b[Block(1, 1)] == a[Block(1, 1)]
236+
@test b[Block(1, 2)] == a[Block(1, 2)]
234237

235-
# This is outputting only zero blocks.
236238
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
237239
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
238240
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
239-
b = a[Block(2):Block(2), Block(1):Block(2)]
240-
@test_broken block_nstored(b) == 1
241-
@test_broken b == Array(a)[3:5, 1:end]
241+
b = copy(a)
242+
x = randn(elt, size(@view(a[Block(2, 2)])))
243+
b[Block(2), Block(2)] = x
244+
@test b[Block(2, 2)] == x
242245

243246
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
244247
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
245248
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
246249
b = copy(a)
247-
x = randn(size(@view(a[Block(2, 2)])))
248-
b[Block(2), Block(2)] = x
249-
@test_broken b[Block(2, 2)] == x
250+
b[Block(1, 1)] .= 1
251+
# TODO: Use `blocksizes(b)[1, 1]` once we upgrade to
252+
# BlockArrays.jl v1.
253+
@test b[Block(1, 1)] == trues(size(@view(b[Block(1, 1)])))
250254

251-
# Doesnt' set the block
255+
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
256+
x = randn(elt, 1, 2)
257+
@view(a[Block(2, 2)])[1:1, 1:2] = x
258+
@test @view(a[Block(2, 2)])[1:1, 1:2] == x
259+
@test a[Block(2, 2)][1:1, 1:2] == x
260+
261+
# TODO: This is broken, fix!
262+
@test_broken a[3:3, 4:5] == x
263+
264+
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
265+
x = randn(elt, 1, 2)
266+
@views a[Block(2, 2)][1:1, 1:2] = x
267+
@test @view(a[Block(2, 2)])[1:1, 1:2] == x
268+
@test a[Block(2, 2)][1:1, 1:2] == x
269+
270+
# TODO: This is broken, fix!
271+
@test_broken a[3:3, 4:5] == x
272+
273+
## Broken, need to fix.
274+
275+
# This is outputting only zero blocks.
252276
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
253277
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
254278
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
255-
b = copy(a)
256-
b[Block(1, 1)] .= 1
257-
@test_broken b[1, 1] == trues(size(@view(b[1, 1])))
279+
b = a[Block(2):Block(2), Block(1):Block(2)]
280+
@test_broken block_nstored(b) == 1
281+
@test_broken b == Array(a)[3:5, 1:end]
258282
end
259283
@testset "LinearAlgebra" begin
260284
a1 = BlockSparseArray{elt}([2, 3], [2, 3])

NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,11 @@ function blockedunitrange_getindices(
221221
return mortar(map(index -> a[index], indices))
222222
end
223223

224+
# TODO: Move this to a `BlockArraysExtensions` library.
225+
function blockedunitrange_getindices(a::BlockedUnitRange, indices::Block{1})
226+
return a[indices]
227+
end
228+
224229
# TODO: Move this to a `BlockArraysExtensions` library.
225230
function blockedunitrange_getindices(a::BlockedUnitRange, indices)
226231
return error("Not implemented.")

NDTensors/src/lib/GradedAxes/src/unitrangedual.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ function unitrangedual_getindices_blocks(a, indices)
3838
return mortar([dual(b) for b in blocks(a_indices)])
3939
end
4040

41+
# TODO: Move this to a `BlockArraysExtensions` library.
42+
function blockedunitrange_getindices(a::UnitRangeDual, indices::Block{1})
43+
return a[indices]
44+
end
45+
4146
function Base.getindex(a::UnitRangeDual, indices::Vector{<:Block{1}})
4247
return unitrangedual_getindices_blocks(a, indices)
4348
end
@@ -54,6 +59,12 @@ function BlockArrays.BlockSlice(b::Block, a::LabelledUnitRange)
5459
return BlockSlice(b, unlabel(a))
5560
end
5661

62+
using BlockArrays: BlockArrays, BlockSlice
63+
using NDTensors.GradedAxes: UnitRangeDual, dual
64+
function BlockArrays.BlockSlice(b::Block, r::UnitRangeDual)
65+
return BlockSlice(b, dual(r))
66+
end
67+
5768
using NDTensors.LabelledNumbers: LabelledNumbers, label
5869
LabelledNumbers.label(a::UnitRangeDual) = dual(label(nondual(a)))
5970

NDTensors/src/lib/LabelledNumbers/src/labelledinteger.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,32 @@ Base.:-(x::LabelledInteger) = labelled_minus(x)
8686
# TODO: This is only needed for older Julia versions, like Julia 1.6.
8787
# Delete once we drop support for older Julia versions.
8888
Base.hash(x::LabelledInteger, h::UInt64) = labelled_hash(x, h)
89+
90+
using Random: AbstractRNG, default_rng
91+
default_eltype() = Float64
92+
for f in [:rand, :randn]
93+
@eval begin
94+
function Base.$f(
95+
rng::AbstractRNG,
96+
elt::Type{<:Number},
97+
dims::Tuple{LabelledInteger,Vararg{LabelledInteger}},
98+
)
99+
return a = $f(rng, elt, unlabel.(dims))
100+
end
101+
function Base.$f(
102+
rng::AbstractRNG,
103+
elt::Type{<:Number},
104+
dim1::LabelledInteger,
105+
dims::Vararg{LabelledInteger},
106+
)
107+
return $f(rng, elt, (dim1, dims...))
108+
end
109+
Base.$f(elt::Type{<:Number}, dims::Tuple{LabelledInteger,Vararg{LabelledInteger}}) =
110+
$f(default_rng(), elt, dims)
111+
Base.$f(elt::Type{<:Number}, dim1::LabelledInteger, dims::Vararg{LabelledInteger}) =
112+
$f(elt, (dim1, dims...))
113+
Base.$f(dims::Tuple{LabelledInteger,Vararg{LabelledInteger}}) =
114+
$f(default_eltype(), dims)
115+
Base.$f(dim1::LabelledInteger, dims::Vararg{LabelledInteger}) = $f((dim1, dims...))
116+
end
117+
end

0 commit comments

Comments
 (0)