Skip to content

Commit 10a6563

Browse files
authored
[BlockSparseArrays] Define more constructors (#1586)
* [BlockSparseArrays] Define more constructors * [NDTensors] Bump to v0.3.68
1 parent 7faad33 commit 10a6563

File tree

4 files changed

+126
-27
lines changed

4 files changed

+126
-27
lines changed

NDTensors/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NDTensors"
22
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
33
authors = ["Matthew Fishman <[email protected]>"]
4-
version = "0.3.67"
4+
version = "0.3.68"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl

Lines changed: 53 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,15 @@ struct BlockSparseArray{
1616
axes::Axes
1717
end
1818

19-
const BlockSparseMatrix{T,A,Blocks,Axes} = BlockSparseArray{T,2,A,Blocks,Axes}
20-
const BlockSparseVector{T,A,Blocks,Axes} = BlockSparseArray{T,1,A,Blocks,Axes}
19+
# TODO: Can this definition be shortened?
20+
const BlockSparseMatrix{T,A<:AbstractMatrix{T},Blocks<:AbstractMatrix{A},Axes<:Tuple{AbstractUnitRange,AbstractUnitRange}} = BlockSparseArray{
21+
T,2,A,Blocks,Axes
22+
}
23+
24+
# TODO: Can this definition be shortened?
25+
const BlockSparseVector{T,A<:AbstractVector{T},Blocks<:AbstractVector{A},Axes<:Tuple{AbstractUnitRange}} = BlockSparseArray{
26+
T,1,A,Blocks,Axes
27+
}
2128

2229
function BlockSparseArray(
2330
block_data::Dictionary{<:Block{N},<:AbstractArray{<:Any,N}},
@@ -68,10 +75,38 @@ function BlockSparseArray{T,N,A}(
6875
return BlockSparseArray{T,N,A}(blocks, axes)
6976
end
7077

78+
function BlockSparseArray{T,N,A}(
79+
axes::Vararg{AbstractUnitRange,N}
80+
) where {T,N,A<:AbstractArray{T,N}}
81+
return BlockSparseArray{T,N,A}(axes)
82+
end
83+
84+
function BlockSparseArray{T,N,A}(
85+
dims::Tuple{Vararg{Vector{Int},N}}
86+
) where {T,N,A<:AbstractArray{T,N}}
87+
return BlockSparseArray{T,N,A}(blockedrange.(dims))
88+
end
89+
90+
# Fix ambiguity error.
91+
function BlockSparseArray{T,0,A}(axes::Tuple{}) where {T,A<:AbstractArray{T,0}}
92+
blocks = default_blocks(A, axes)
93+
return BlockSparseArray{T,0,A}(blocks, axes)
94+
end
95+
96+
function BlockSparseArray{T,N,A}(
97+
dims::Vararg{Vector{Int},N}
98+
) where {T,N,A<:AbstractArray{T,N}}
99+
return BlockSparseArray{T,N,A}(dims)
100+
end
101+
71102
function BlockSparseArray{T,N}(axes::Tuple{Vararg{AbstractUnitRange,N}}) where {T,N}
72103
return BlockSparseArray{T,N,default_arraytype(T, axes)}(axes)
73104
end
74105

106+
function BlockSparseArray{T,N}(axes::Vararg{AbstractUnitRange,N}) where {T,N}
107+
return BlockSparseArray{T,N}(axes)
108+
end
109+
75110
function BlockSparseArray{T,0}(axes::Tuple{}) where {T}
76111
return BlockSparseArray{T,0,default_arraytype(T, axes)}(axes)
77112
end
@@ -80,6 +115,10 @@ function BlockSparseArray{T,N}(dims::Tuple{Vararg{Vector{Int},N}}) where {T,N}
80115
return BlockSparseArray{T,N}(blockedrange.(dims))
81116
end
82117

118+
function BlockSparseArray{T,N}(dims::Vararg{Vector{Int},N}) where {T,N}
119+
return BlockSparseArray{T,N}(dims)
120+
end
121+
83122
function BlockSparseArray{T}(dims::Tuple{Vararg{Vector{Int}}}) where {T}
84123
return BlockSparseArray{T,length(dims)}(dims)
85124
end
@@ -104,37 +143,25 @@ function BlockSparseArray{T}() where {T}
104143
return BlockSparseArray{T}(())
105144
end
106145

107-
function BlockSparseArray{T,N,A}(
108-
::UndefInitializer, dims::Tuple
109-
) where {T,N,A<:AbstractArray{T,N}}
110-
return BlockSparseArray{T,N,A}(dims)
111-
end
112-
113146
# undef
114-
function BlockSparseArray{T,N}(
115-
::UndefInitializer, axes::Tuple{Vararg{AbstractUnitRange,N}}
116-
) where {T,N}
117-
return BlockSparseArray{T,N}(axes)
118-
end
119-
120-
function BlockSparseArray{T,N}(
121-
::UndefInitializer, dims::Tuple{Vararg{Vector{Int},N}}
122-
) where {T,N}
123-
return BlockSparseArray{T,N}(dims)
147+
function BlockSparseArray{T,N,A,Blocks}(
148+
::UndefInitializer, args...
149+
) where {T,N,A<:AbstractArray{T,N},Blocks<:AbstractArray{A,N}}
150+
return BlockSparseArray{T,N,A,Blocks}(args...)
124151
end
125152

126-
function BlockSparseArray{T}(
127-
::UndefInitializer, axes::Tuple{Vararg{AbstractUnitRange}}
128-
) where {T}
129-
return BlockSparseArray{T}(axes)
153+
function BlockSparseArray{T,N,A}(
154+
::UndefInitializer, args...
155+
) where {T,N,A<:AbstractArray{T,N}}
156+
return BlockSparseArray{T,N,A}(args...)
130157
end
131158

132-
function BlockSparseArray{T}(::UndefInitializer, dims::Tuple{Vararg{Vector{Int}}}) where {T}
133-
return BlockSparseArray{T}(dims)
159+
function BlockSparseArray{T,N}(::UndefInitializer, args...) where {T,N}
160+
return BlockSparseArray{T,N}(args...)
134161
end
135162

136-
function BlockSparseArray{T}(::UndefInitializer, dims::Vararg{Vector{Int}}) where {T}
137-
return BlockSparseArray{T}(dims...)
163+
function BlockSparseArray{T}(::UndefInitializer, args...) where {T}
164+
return BlockSparseArray{T}(args...)
138165
end
139166

140167
# Base `AbstractArray` interface

NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ using ..SparseArrayInterface: perm, iperm, stored_length, sparse_zero!
1717

1818
blocksparse_blocks(a::AbstractArray) = error("Not implemented")
1919

20+
blockstype(a::AbstractArray) = blockstype(typeof(a))
21+
2022
function blocksparse_getindex(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where {N}
2123
@boundscheck checkbounds(a, I...)
2224
return a[findblockindex.(axes(a), I)...]

NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,18 @@ using LinearAlgebra: Adjoint, dot, mul!, norm
2020
using NDTensors.BlockSparseArrays:
2121
@view!,
2222
BlockSparseArray,
23+
BlockSparseMatrix,
24+
BlockSparseVector,
2325
BlockView,
2426
block_stored_length,
2527
block_reshape,
2628
block_stored_indices,
29+
blockstype,
30+
blocktype,
2731
view!
2832
using NDTensors.GPUArraysCoreExtensions: cpu
2933
using NDTensors.SparseArrayInterface: stored_length
34+
using NDTensors.SparseArrayDOKs: SparseArrayDOK, SparseMatrixDOK, SparseVectorDOK
3035
using NDTensors.TensorAlgebra: contract
3136
using Test: @test, @test_broken, @test_throws, @testset
3237
include("TestBlockSparseArraysUtils.jl")
@@ -72,6 +77,71 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
7277
ah = adjoint(a)
7378
@test_broken [ah[Block(Tuple(it))] for it in eachindex(block_stored_indices(ah))] isa Vector
7479
end
80+
@testset "Constructors" begin
81+
# BlockSparseMatrix
82+
bs = ([2, 3], [3, 4])
83+
for T in (
84+
BlockSparseArray{elt},
85+
BlockSparseArray{elt,2},
86+
BlockSparseMatrix{elt},
87+
BlockSparseArray{elt,2,Matrix{elt}},
88+
BlockSparseMatrix{elt,Matrix{elt}},
89+
## BlockSparseArray{elt,2,Matrix{elt},SparseMatrixDOK{Matrix{elt}}}, # TODO
90+
## BlockSparseMatrix{elt,Matrix{elt},SparseMatrixDOK{Matrix{elt}}}, # TODO
91+
)
92+
for args in (
93+
bs,
94+
(bs,),
95+
blockedrange.(bs),
96+
(blockedrange.(bs),),
97+
(undef, bs),
98+
(undef, bs...),
99+
(undef, blockedrange.(bs)),
100+
(undef, blockedrange.(bs)...),
101+
)
102+
a = T(args...)
103+
@test eltype(a) == elt
104+
@test blocktype(a) == Matrix{elt}
105+
@test blockstype(a) <: SparseMatrixDOK{Matrix{elt}}
106+
@test blocklengths.(axes(a)) == ([2, 3], [3, 4])
107+
@test iszero(a)
108+
@test iszero(block_stored_length(a))
109+
@test iszero(stored_length(a))
110+
end
111+
end
112+
113+
# BlockSparseVector
114+
bs = ([2, 3],)
115+
for T in (
116+
BlockSparseArray{elt},
117+
BlockSparseArray{elt,1},
118+
BlockSparseVector{elt},
119+
BlockSparseArray{elt,1,Vector{elt}},
120+
BlockSparseVector{elt,Vector{elt}},
121+
## BlockSparseArray{elt,1,Vector{elt},SparseVectorDOK{Vector{elt}}}, # TODO
122+
## BlockSparseVector{elt,Vector{elt},SparseVectorDOK{Vector{elt}}}, # TODO
123+
)
124+
for args in (
125+
bs,
126+
(bs,),
127+
blockedrange.(bs),
128+
(blockedrange.(bs),),
129+
(undef, bs),
130+
(undef, bs...),
131+
(undef, blockedrange.(bs)),
132+
(undef, blockedrange.(bs)...),
133+
)
134+
a = T(args...)
135+
@test eltype(a) == elt
136+
@test blocktype(a) == Vector{elt}
137+
@test blockstype(a) <: SparseVectorDOK{Vector{elt}}
138+
@test blocklengths.(axes(a)) == ([2, 3],)
139+
@test iszero(a)
140+
@test iszero(block_stored_length(a))
141+
@test iszero(stored_length(a))
142+
end
143+
end
144+
end
75145
@testset "Basics" begin
76146
a = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
77147
@allowscalar @test a == dev(

0 commit comments

Comments
 (0)