Skip to content

Commit f7e162f

Browse files
authored
Fix a bug in blocktype (#40)
1 parent 88aa06f commit f7e162f

File tree

4 files changed

+76
-23
lines changed

4 files changed

+76
-23
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.13"
4+
version = "0.2.14"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -41,31 +41,52 @@ function eachstoredblock(a::AbstractArray)
4141
return storedvalues(blocks(a))
4242
end
4343

44-
# TODO: Generalize this, this catches simple cases
45-
# where the more general definition isn't specific enough.
46-
blocktype(a::Array) = typeof(a)
47-
# TODO: Maybe unwrap SubArrays?
44+
function blockstype(a::AbstractArray)
45+
return typeof(blocks(a))
46+
end
47+
48+
#=
49+
Ideally this would just be defined as `eltype(blockstype(a))`.
50+
However, BlockArrays.jl doesn't make `eltype(blocks(a))` concrete
51+
even when it could be
52+
(https://github.com/JuliaArrays/BlockArrays.jl/blob/v1.4.0/src/blocks.jl#L71-L74):
53+
```julia
54+
julia> eltype(blocks(BlockArray(randn(2, 2), [1, 1], [1, 1])))
55+
Matrix{Float64} (alias for Array{Float64, 2})
56+
57+
julia> eltype(blocks(BlockedArray(randn(2, 2), [1, 1], [1, 1])))
58+
AbstractMatrix{Float64} (alias for AbstractArray{Float64, 2})
59+
60+
julia> eltype(blocks(randn(2, 2)))
61+
AbstractMatrix{Float64} (alias for AbstractArray{Float64, 2})
62+
```
63+
Also note the current definition errors in the limit
64+
when `blocks(a)` is empty, but even empty arrays generally
65+
have at least one block:
66+
```julia
67+
julia> length(blocks(randn(0)))
68+
1
69+
70+
julia> length(blocks(BlockVector{Float64}(randn(0))))
71+
1
72+
73+
julia> length(blocks(BlockedVector{Float64}(randn(0))))
74+
1
75+
```
76+
=#
4877
function blocktype(a::AbstractArray)
49-
# TODO: Unfortunately, this doesn't always give
50-
# a concrete type, even when it could be concrete, i.e.
51-
#=
52-
```julia
53-
julia> eltype(blocks(BlockArray(randn(2, 2), [1, 1], [1, 1])))
54-
Matrix{Float64} (alias for Array{Float64, 2})
55-
56-
julia> eltype(blocks(BlockedArray(randn(2, 2), [1, 1], [1, 1])))
57-
AbstractMatrix{Float64} (alias for AbstractArray{Float64, 2})
58-
59-
julia> eltype(blocks(randn(2, 2)))
60-
AbstractMatrix{Float64} (alias for AbstractArray{Float64, 2})
61-
```
62-
=#
6378
if isempty(blocks(a))
64-
return eltype(blocks(a))
79+
error("`blocktype` can't be determined if `isempty(blocks(a))`.")
6580
end
66-
return eltype(first(blocks(a)))
81+
return mapreduce(typeof, promote_type, blocks(a))
6782
end
6883

84+
using BlockArrays: BlockArray
85+
blockstype(::Type{<:BlockArray{<:Any,<:Any,B}}) where {B} = B
86+
blockstype(a::BlockArray) = blockstype(typeof(a))
87+
blocktype(arraytype::Type{<:BlockArray}) = eltype(blockstype(arraytype))
88+
blocktype(a::BlockArray) = eltype(blocks(a))
89+
6990
abstract type AbstractBlockSparseArrayInterface <: AbstractSparseArrayInterface end
7091

7192
# TODO: Also support specifying the `blocktype` along with the `eltype`.
@@ -78,8 +99,6 @@ struct BlockSparseArrayInterface <: AbstractBlockSparseArrayInterface end
7899
@interface ::AbstractBlockSparseArrayInterface BlockArrays.blocks(a::AbstractArray) =
79100
error("Not implemented")
80101

81-
blockstype(a::AbstractArray) = blockstype(typeof(a))
82-
83102
@interface ::AbstractBlockSparseArrayInterface function Base.getindex(
84103
a::AbstractArray{<:Any,N}, I::Vararg{Int,N}
85104
) where {N}

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
1818
SymmetrySectors = "f8a8ad64-adbc-4fce-92f7-ffe2bb36a86e"
1919
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
2020
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
21+
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"

test/test_basics.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
using Adapt: adapt
22
using ArrayLayouts: zero!
33
using BlockArrays:
4+
BlockArrays,
45
Block,
6+
BlockArray,
57
BlockIndexRange,
68
BlockRange,
79
BlockSlice,
810
BlockVector,
911
BlockedOneTo,
1012
BlockedUnitRange,
13+
BlockedArray,
1114
BlockedVector,
1215
blockedrange,
1316
blocklength,
@@ -34,6 +37,7 @@ using LinearAlgebra: Adjoint, Transpose, dot, mul!, norm
3437
using SparseArraysBase: SparseArrayDOK, SparseMatrixDOK, SparseVectorDOK, storedlength
3538
using TensorAlgebra: contract
3639
using Test: @test, @test_broken, @test_throws, @testset, @inferred
40+
using TestExtras: @constinferred
3741
include("TestBlockSparseArraysUtils.jl")
3842

3943
arrayts = (Array, JLArray)
@@ -132,6 +136,35 @@ arrayts = (Array, JLArray)
132136
end
133137
end
134138
end
139+
@testset "blockstype, blocktype" begin
140+
a = arrayt(randn(elt, 2, 2))
141+
@test (@constinferred blockstype(a)) <: BlockArrays.BlocksView{elt,2}
142+
# TODO: This is difficult to determine just from type information.
143+
@test_broken blockstype(typeof(a)) <: BlockArrays.BlocksView{elt,2}
144+
@test (@constinferred blocktype(a)) <: SubArray{elt,2,arrayt{elt,2}}
145+
# TODO: This is difficult to determine just from type information.
146+
@test_broken blocktype(typeof(a)) <: SubArray{elt,2,arrayt{elt,2}}
147+
148+
a = BlockSparseMatrix{elt,arrayt{elt,2}}([1, 1], [1, 1])
149+
@test (@constinferred blockstype(a)) <: SparseMatrixDOK{arrayt{elt,2}}
150+
@test (@constinferred blockstype(typeof(a))) <: SparseMatrixDOK{arrayt{elt,2}}
151+
@test (@constinferred blocktype(a)) <: arrayt{elt,2}
152+
@test (@constinferred blocktype(typeof(a))) <: arrayt{elt,2}
153+
154+
a = BlockArray(arrayt(randn(elt, (2, 2))), [1, 1], [1, 1])
155+
@test (@constinferred blockstype(a)) === Matrix{arrayt{elt,2}}
156+
@test (@constinferred blockstype(typeof(a))) === Matrix{arrayt{elt,2}}
157+
@test (@constinferred blocktype(a)) <: arrayt{elt,2}
158+
@test (@constinferred blocktype(typeof(a))) <: arrayt{elt,2}
159+
160+
a = BlockedArray(arrayt(randn(elt, 2, 2)), [1, 1], [1, 1])
161+
@test (@constinferred blockstype(a)) <: BlockArrays.BlocksView{elt,2}
162+
# TODO: This is difficult to determine just from type information.
163+
@test_broken blockstype(typeof(a)) <: BlockArrays.BlocksView{elt,2}
164+
@test (@constinferred blocktype(a)) <: SubArray{elt,2,arrayt{elt,2}}
165+
# TODO: This is difficult to determine just from type information.
166+
@test_broken blocktype(typeof(a)) <: SubArray{elt,2,arrayt{elt,2}}
167+
end
135168
@testset "Basics" begin
136169
a = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
137170
@allowscalar @test a == dev(

0 commit comments

Comments
 (0)