Skip to content

Commit 49c1202

Browse files
authored
[BlockSparseArrays] Zero dimensional block sparse array and some fixes for Adjoint and PermutedDimsArray (#1574)
* [BlockSparseArrays] Zero dimensional block sparse array and some fixes for Adjoint and PermutedDimsArray * [NDTensors] Bump to v0.3.59
1 parent f9b6309 commit 49c1202

File tree

11 files changed

+159
-6
lines changed

11 files changed

+159
-6
lines changed

NDTensors/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NDTensors"
22
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
33
authors = ["Matthew Fishman <[email protected]>"]
4-
version = "0.3.58"
4+
version = "0.3.59"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ function Base.show(
138138
io::IO, mime::MIME"text/plain", a::Adjoint{<:Any,<:BlockSparseMatrix}; kwargs...
139139
)
140140
axes_a = axes(a)
141-
a_nondual = BlockSparseArray(blocks(a'), dual.(nondual.(axes(a))))'
141+
a_nondual = BlockSparseArray(blocks(a'), dual.(nondual.(axes(a'))))'
142142
return blocksparse_show(io, mime, a_nondual, axes_a; kwargs...)
143143
end
144144

NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
@eval module $(gensym())
2-
using Compat: Returns
32
using Test: @test, @testset
43
using BlockArrays:
54
AbstractBlockArray, Block, BlockedOneTo, blockedrange, blocklengths, blocksize
@@ -287,6 +286,15 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
287286
@test ax isa typeof(dual(r))
288287
end
289288

289+
@test !isdual(axes(a, 1))
290+
@test !isdual(axes(a, 2))
291+
@test isdual(axes(a', 1))
292+
@test isdual(axes(a', 2))
293+
@test isdual(axes(b, 1))
294+
@test isdual(axes(b, 2))
295+
@test isdual(axes(copy(a'), 1))
296+
@test isdual(axes(copy(a'), 2))
297+
290298
I = [Block(1)[1:1]]
291299
@test size(b[I, :]) == (1, 4)
292300
@test size(b[:, I]) == (4, 1)

NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ include("abstractblocksparsearray/arraylayouts.jl")
1616
include("abstractblocksparsearray/sparsearrayinterface.jl")
1717
include("abstractblocksparsearray/broadcast.jl")
1818
include("abstractblocksparsearray/map.jl")
19+
include("abstractblocksparsearray/linearalgebra.jl")
1920
include("blocksparsearray/defaults.jl")
2021
include("blocksparsearray/blocksparsearray.jl")
2122
include("BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl")

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsearray.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,16 @@ function blockstype(arraytype::Type{<:AbstractBlockSparseArray{T,N}}) where {T,N
2222
return SparseArrayDOK{AbstractArray{T,N},N}
2323
end
2424

25-
## # Specialized in order to fix ambiguity error with `BlockArrays`.
25+
# Specialized in order to fix ambiguity error with `BlockArrays`.
2626
function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Vararg{Int,N}) where {N}
2727
return blocksparse_getindex(a, I...)
2828
end
2929

30+
# Specialized in order to fix ambiguity error with `BlockArrays`.
31+
function Base.getindex(a::AbstractBlockSparseArray{<:Any,0})
32+
return blocksparse_getindex(a)
33+
end
34+
3035
## # Fix ambiguity error with `BlockArrays`.
3136
## function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Block{N}) where {N}
3237
## return ArrayLayouts.layout_getindex(a, I)
@@ -51,6 +56,12 @@ function Base.setindex!(
5156
return a
5257
end
5358

59+
# Fix ambiguity error.
60+
function Base.setindex!(a::AbstractBlockSparseArray{<:Any,0}, value)
61+
blocksparse_setindex!(a, value)
62+
return a
63+
end
64+
5465
function Base.setindex!(
5566
a::AbstractBlockSparseArray{<:Any,N}, value, I::Vararg{Block{1},N}
5667
) where {N}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
using LinearAlgebra: Adjoint, Transpose
2+
3+
# Like: https://github.com/JuliaLang/julia/blob/v1.11.1/stdlib/LinearAlgebra/src/transpose.jl#L184
4+
# but also takes the dual of the axes.
5+
# Fixes an issue raised in:
6+
# https://github.com/ITensor/ITensors.jl/issues/1336#issuecomment-2353434147
7+
function Base.copy(a::Adjoint{T,<:AbstractBlockSparseMatrix{T}}) where {T}
8+
a_dest = similar(parent(a), axes(a))
9+
a_dest .= a
10+
return a_dest
11+
end
12+
13+
# More efficient than the generic `LinearAlgebra` version.
14+
function Base.copy(a::Transpose{T,<:AbstractBlockSparseMatrix{T}}) where {T}
15+
a_dest = similar(parent(a), axes(a))
16+
a_dest .= a
17+
return a_dest
18+
end

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ function Base.getindex(
9292
)
9393
return ArrayLayouts.layout_getindex(a, I...)
9494
end
95+
# Fixes ambiguity error.
96+
function Base.getindex(a::BlockSparseArrayLike{<:Any,0})
97+
return ArrayLayouts.layout_getindex(a)
98+
end
9599

96100
# TODO: Define `blocksparse_isassigned`.
97101
function Base.isassigned(
@@ -100,6 +104,11 @@ function Base.isassigned(
100104
return isassigned(blocks(a), Int.(index)...)
101105
end
102106

107+
# Fix ambiguity error.
108+
function Base.isassigned(a::BlockSparseArrayLike{<:Any,0})
109+
return isassigned(blocks(a))
110+
end
111+
103112
function Base.isassigned(a::BlockSparseArrayLike{<:Any,N}, index::Block{N}) where {N}
104113
return isassigned(a, Tuple(index)...)
105114
end
@@ -211,6 +220,11 @@ function Base.similar(
211220
return blocksparse_similar(a, elt, axes)
212221
end
213222

223+
# Fixes ambiguity error.
224+
function Base.similar(a::BlockSparseArrayLike{<:Any,0}, elt::Type, axes::Tuple{})
225+
return blocksparse_similar(a, elt, axes)
226+
end
227+
214228
# Fixes ambiguity error with `BlockArrays`.
215229
function Base.similar(
216230
a::BlockSparseArrayLike,
@@ -259,3 +273,22 @@ function Base.similar(
259273
)
260274
return blocksparse_similar(a, elt, axes)
261275
end
276+
277+
# TODO: Implement this in a more generic way using a smarter `copyto!`,
278+
# which is ultimately what `Array{T,N}(::AbstractArray{<:Any,N})` calls.
279+
# These are defined for now to avoid scalar indexing issues when there
280+
# are blocks on GPU.
281+
function Base.Array{T,N}(a::BlockSparseArrayLike{<:Any,N}) where {T,N}
282+
# First make it dense, then move to CPU.
283+
# Directly copying to CPU causes some issues with
284+
# scalar indexing on GPU which we have to investigate.
285+
a_dest = similartype(blocktype(a), T)(undef, size(a))
286+
a_dest .= a
287+
return Array{T,N}(a_dest)
288+
end
289+
function Base.Array{T}(a::BlockSparseArrayLike) where {T}
290+
return Array{T,ndims(a)}(a)
291+
end
292+
function Base.Array(a::BlockSparseArrayLike)
293+
return Array{eltype(a)}(a)
294+
end

NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ function BlockSparseArray{T,N}(axes::Tuple{Vararg{AbstractUnitRange,N}}) where {
7272
return BlockSparseArray{T,N,default_arraytype(T, axes)}(axes)
7373
end
7474

75+
function BlockSparseArray{T,0}(axes::Tuple{}) where {T}
76+
return BlockSparseArray{T,0,default_arraytype(T, axes)}(axes)
77+
end
78+
7579
function BlockSparseArray{T,N}(dims::Tuple{Vararg{Vector{Int},N}}) where {T,N}
7680
return BlockSparseArray{T,N}(blockedrange.(dims))
7781
end
@@ -84,6 +88,10 @@ function BlockSparseArray{T}(axes::Tuple{Vararg{AbstractUnitRange}}) where {T}
8488
return BlockSparseArray{T,length(axes)}(axes)
8589
end
8690

91+
function BlockSparseArray{T}(axes::Tuple{}) where {T}
92+
return BlockSparseArray{T,length(axes)}(axes)
93+
end
94+
8795
function BlockSparseArray{T}(dims::Vararg{Vector{Int}}) where {T}
8896
return BlockSparseArray{T}(dims)
8997
end
@@ -92,6 +100,10 @@ function BlockSparseArray{T}(axes::Vararg{AbstractUnitRange}) where {T}
92100
return BlockSparseArray{T}(axes)
93101
end
94102

103+
function BlockSparseArray{T}() where {T}
104+
return BlockSparseArray{T}(())
105+
end
106+
95107
function BlockSparseArray{T,N,A}(
96108
::UndefInitializer, dims::Tuple
97109
) where {T,N,A<:AbstractArray{T,N}}

NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ function blocksparse_getindex(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where
2222
return a[findblockindex.(axes(a), I)...]
2323
end
2424

25+
# Fix ambiguity error.
26+
function blocksparse_getindex(a::AbstractArray{<:Any,0})
27+
# TODO: Use `Block()[]` once https://github.com/JuliaArrays/BlockArrays.jl/issues/430
28+
# is fixed.
29+
return a[BlockIndex{0,Tuple{},Tuple{}}((), ())]
30+
end
31+
2532
# a[1:2, 1:2]
2633
# TODO: This definition means that the result of slicing a block sparse array
2734
# with a non-blocked unit range is blocked. We may want to change that behavior,
@@ -77,6 +84,14 @@ function blocksparse_setindex!(a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N
7784
return a
7885
end
7986

87+
# Fix ambiguity error.
88+
function blocksparse_setindex!(a::AbstractArray{<:Any,0}, value)
89+
# TODO: Use `Block()[]` once https://github.com/JuliaArrays/BlockArrays.jl/issues/430
90+
# is fixed.
91+
a[BlockIndex{0,Tuple{},Tuple{}}((), ())] = value
92+
return a
93+
end
94+
8095
function blocksparse_setindex!(a::AbstractArray{<:Any,N}, value, I::BlockIndex{N}) where {N}
8196
i = Int.(Tuple(block(I)))
8297
a_b = blocks(a)[i...]
@@ -86,6 +101,15 @@ function blocksparse_setindex!(a::AbstractArray{<:Any,N}, value, I::BlockIndex{N
86101
return a
87102
end
88103

104+
# Fix ambiguity error.
105+
function blocksparse_setindex!(a::AbstractArray{<:Any,0}, value, I::BlockIndex{0})
106+
a_b = blocks(a)[]
107+
a_b[] = value
108+
# Set the block, required if it is structurally zero.
109+
blocks(a)[] = a_b
110+
return a
111+
end
112+
89113
function blocksparse_fill!(a::AbstractArray, value)
90114
for b in BlockRange(a)
91115
# We can't use:
@@ -119,7 +143,8 @@ end
119143
using ..SparseArrayInterface:
120144
SparseArrayInterface, AbstractSparseArray, AbstractSparseMatrix
121145

122-
_perm(::PermutedDimsArray{<:Any,<:Any,P}) where {P} = P
146+
_perm(::PermutedDimsArray{<:Any,<:Any,perm}) where {perm} = perm
147+
_invperm(::PermutedDimsArray{<:Any,<:Any,<:Any,invperm}) where {invperm} = invperm
123148
_getindices(t::Tuple, indices) = map(i -> t[i], indices)
124149
_getindices(i::CartesianIndex, indices) = CartesianIndex(_getindices(Tuple(i), indices))
125150

@@ -140,7 +165,7 @@ function Base.getindex(
140165
a::SparsePermutedDimsArrayBlocks{<:Any,N}, index::Vararg{Int,N}
141166
) where {N}
142167
return PermutedDimsArray(
143-
blocks(parent(a.array))[_getindices(index, _perm(a.array))...], _perm(a.array)
168+
blocks(parent(a.array))[_getindices(index, _invperm(a.array))...], _perm(a.array)
144169
)
145170
end
146171
function SparseArrayInterface.stored_indices(a::SparsePermutedDimsArrayBlocks)

NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,37 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
109109

110110
a[3, 3] = NaN
111111
@test isnan(norm(a))
112+
113+
# Empty constructor
114+
for a in (dev(BlockSparseArray{elt}()), dev(BlockSparseArray{elt}(undef)))
115+
@test size(a) == ()
116+
@test isone(length(a))
117+
@test blocksize(a) == ()
118+
@test blocksizes(a) == fill(())
119+
@test iszero(block_nstored(a))
120+
@test iszero(@allowscalar(a[]))
121+
@test iszero(@allowscalar(a[CartesianIndex()]))
122+
@test a[Block()] == dev(fill(0))
123+
@test iszero(@allowscalar(a[Block()][]))
124+
# Broken:
125+
## @test b[Block()[]] == 2
126+
for b in (
127+
(b = copy(a); @allowscalar b[] = 2; b),
128+
(b = copy(a); @allowscalar b[CartesianIndex()] = 2; b),
129+
)
130+
@test size(b) == ()
131+
@test isone(length(b))
132+
@test blocksize(b) == ()
133+
@test blocksizes(b) == fill(())
134+
@test isone(block_nstored(b))
135+
@test @allowscalar(b[]) == 2
136+
@test @allowscalar(b[CartesianIndex()]) == 2
137+
@test b[Block()] == dev(fill(2))
138+
@test @allowscalar(b[Block()][]) == 2
139+
# Broken:
140+
## @test b[Block()[]] == 2
141+
end
142+
end
112143
end
113144
@testset "Tensor algebra" begin
114145
a = dev(BlockSparseArray{elt}(undef, ([2, 3], [3, 4])))
@@ -266,6 +297,15 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
266297
@test block_nstored(b) == 2
267298
@test nstored(b) == 2 * 4 + 3 * 3
268299

300+
a = dev(BlockSparseArray{elt}([1, 1, 1], [1, 2, 3], [2, 2, 1], [1, 2, 1]))
301+
a[Block(3, 2, 2, 3)] = dev(randn(elt, 1, 2, 2, 1))
302+
perm = (2, 3, 4, 1)
303+
for b in (PermutedDimsArray(a, perm), permutedims(a, perm))
304+
@test Array(b) == permutedims(Array(a), perm)
305+
@test issetequal(block_stored_indices(b), [Block(2, 2, 3, 3)])
306+
@test @allowscalar b[Block(2, 2, 3, 3)] == permutedims(a[Block(3, 2, 2, 3)], perm)
307+
end
308+
269309
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
270310
@views for b in [Block(1, 2), Block(2, 1)]
271311
a[b] = randn(elt, size(a[b]))

NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/indexing.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,11 @@ function sparse_getindex(a::AbstractArray, I::Vararg{Int})
8282
return sparse_getindex(a, CartesianIndex(I))
8383
end
8484

85+
# Fix ambiguity error.
86+
function sparse_getindex(a::AbstractArray{<:Any,0})
87+
return sparse_getindex(a, CartesianIndex())
88+
end
89+
8590
# Linear indexing
8691
function sparse_getindex(a::AbstractArray, I::CartesianIndex{1})
8792
return sparse_getindex(a, CartesianIndices(a)[I])

0 commit comments

Comments
 (0)