Skip to content

Commit 10f6ebf

Browse files
authored
Fix adapt, fix printing subarrays (#82)
1 parent 1f833df commit 10f6ebf

File tree

5 files changed

+69
-16
lines changed

5 files changed

+69
-16
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.3"
4+
version = "0.3.4"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/abstractblocksparsearray/views.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,21 @@ function Base.view(
130130
) where {T,N}
131131
return viewblock(a, block)
132132
end
133+
134+
# Fix ambiguity error with BlockArrays.jl for slices like
135+
# `a = BlockSparseArray{Float64}(undef, [2, 2], [2, 2]); @view a[:, :]`.
136+
function Base.view(
137+
a::SubArray{
138+
T,
139+
N,
140+
<:AbstractBlockSparseArray{T,N},
141+
<:Tuple{Vararg{Union{Base.Slice,BlockSlice{<:BlockRange{1}}},N}},
142+
},
143+
block::Block{N},
144+
) where {T,N}
145+
return viewblock(a, block)
146+
end
147+
133148
function Base.view(
134149
a::SubArray{
135150
T,

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -361,14 +361,25 @@ function Base.replace_in_print_matrix(
361361
end
362362

363363
# attempt to catch things that wrap GPU arrays
364-
function Base.print_array(io::IO, X::AnyAbstractBlockSparseArray)
365-
X_cpu = adapt(Array, X)
366-
if typeof(X_cpu) === typeof(X) # prevent infinite recursion
364+
function Base.print_array(io::IO, a::AnyAbstractBlockSparseArray)
365+
a_cpu = adapt(Array, a)
366+
if typeof(a_cpu) === typeof(a) # prevent infinite recursion
367367
# need to specify ndims to allow specialized code for vector/matrix
368368
@allowscalar @invoke Base.print_array(
369-
io, X_cpu::AbstractArray{eltype(X_cpu),ndims(X_cpu)}
369+
io, a_cpu::AbstractArray{eltype(a_cpu),ndims(a_cpu)}
370370
)
371-
else
372-
Base.print_array(io, X_cpu)
371+
return nothing
373372
end
373+
Base.print_array(io, a_cpu)
374+
return nothing
375+
end
376+
377+
using Adapt: Adapt, adapt
378+
function Adapt.adapt_structure(to, a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray})
379+
# In the generic definition in Adapt.jl, `parentindices(a)` are also
380+
# adapted, but is broken when the parent indices contained blocked unit
381+
# ranges since `adapt` is broken on blocked unit ranges.
382+
# TODO: Fix adapt for blocked unit ranges by making an AdaptExt for
383+
# BlockArrays.jl.
384+
return SubArray(adapt(to, parent(a)), parentindices(a))
374385
end

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ using BlockArrays:
1313
blockcheckbounds,
1414
blockisequal,
1515
blocklengths,
16+
blocklength,
1617
blocks,
1718
findblockindex
1819
using DerivableInterfaces: DerivableInterfaces, @interface, DefaultArrayInterface
@@ -412,6 +413,7 @@ end
412413

413414
to_blocks_indices(I::BlockSlice{<:BlockRange{1}}) = Int.(I.block)
414415
to_blocks_indices(I::BlockIndices{<:Vector{<:Block{1}}}) = Int.(I.blocks)
416+
to_blocks_indices(I::Base.Slice{<:BlockedOneTo}) = Base.OneTo(blocklength(I.indices))
415417

416418
@interface ::AbstractBlockSparseArrayInterface function BlockArrays.blocks(
417419
a::SubArray{<:Any,<:Any,<:Any,<:Tuple{Vararg{BlockSliceCollection}}}

test/test_basics.jl

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ using BlockSparseArrays:
3434
sparsemortar,
3535
view!
3636
using GPUArraysCore: @allowscalar
37-
using JLArrays: JLArray
37+
using JLArrays: JLArray, JLMatrix
3838
using LinearAlgebra: Adjoint, Transpose, dot, mul!, norm
3939
using SparseArraysBase: SparseArrayDOK, SparseMatrixDOK, SparseVectorDOK, storedlength
4040
using TensorAlgebra: contract
@@ -306,6 +306,27 @@ arrayts = (Array, JLArray)
306306
@test @views(at[Block(1, 2)]) isa Adjoint
307307
end
308308
end
309+
@testset "adapt" begin
310+
a = BlockSparseArray{elt}(undef, [2, 2], [2, 2])
311+
a_12 = randn(elt, 2, 2)
312+
a[Block(1, 2)] = a_12
313+
a_jl = adapt(JLArray, a)
314+
@test a_jl isa BlockSparseMatrix{elt,JLMatrix{elt}}
315+
@test blocktype(a_jl) == JLMatrix{elt}
316+
@test blockstoredlength(a_jl) == 1
317+
@test a_jl[Block(1, 2)] isa JLMatrix{elt}
318+
@test adapt(Array, a_jl[Block(1, 2)]) == a_12
319+
320+
a = BlockSparseArray{elt}(undef, [2, 2], [2, 2])
321+
a_12 = randn(elt, 2, 2)
322+
a[Block(1, 2)] = a_12
323+
a_jl = adapt(JLArray, @view(a[:, :]))
324+
@test a_jl isa SubArray{elt,2,<:BlockSparseMatrix{elt,JLMatrix{elt}}}
325+
@test blocktype(a_jl) == JLMatrix{elt}
326+
@test blockstoredlength(a_jl) == 1
327+
@test a_jl[Block(1, 2)] isa JLMatrix{elt}
328+
@test adapt(Array, a_jl[Block(1, 2)]) == a_12
329+
end
309330
@testset "Tensor algebra" begin
310331
a = dev(BlockSparseArray{elt}(undef, ([2, 3], [3, 4])))
311332
@views for b in [Block(1, 2), Block(2, 1)]
@@ -1149,15 +1170,19 @@ arrayts = (Array, JLArray)
11491170
# Not testing other element types since they change the
11501171
# spacing so it isn't easy to make the test general.
11511172

1152-
a = BlockSparseMatrix{elt,arrayt{elt,2}}(undef, [2, 2], [2, 2])
1153-
@allowscalar a[1, 2] = 12
1154-
@test sprint(show, "text/plain", a) ==
1155-
"$(summary(a)):\n $(zero(eltype(a))) $(eltype(a)(12)) │ ⋅ ⋅ \n $(zero(eltype(a))) $(zero(eltype(a))) │ ⋅ ⋅ \n ───────────┼──────────\n ⋅ ⋅ │ ⋅ ⋅ \n ⋅ ⋅ │ ⋅ ⋅ "
1173+
a′ = BlockSparseMatrix{elt,arrayt{elt,2}}(undef, [2, 2], [2, 2])
1174+
@allowscalar a′[1, 2] = 12
1175+
for a in (a′, @view(a′[:, :]))
1176+
@test sprint(show, "text/plain", a) ==
1177+
"$(summary(a)):\n $(zero(eltype(a))) $(eltype(a)(12)) │ ⋅ ⋅ \n $(zero(eltype(a))) $(zero(eltype(a))) │ ⋅ ⋅ \n ───────────┼──────────\n ⋅ ⋅ │ ⋅ ⋅ \n ⋅ ⋅ │ ⋅ ⋅ "
1178+
end
11561179

1157-
a = BlockSparseArray{elt,3,arrayt{elt,3}}(undef, [2, 2], [2, 2], [2, 2])
1158-
@allowscalar a[1, 2, 1] = 121
1159-
@test sprint(show, "text/plain", a) ==
1160-
"$(summary(a)):\n[:, :, 1] =\n $(zero(eltype(a))) $(eltype(a)(121)) ⋅ ⋅ \n $(zero(eltype(a))) $(zero(eltype(a))) ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n\n[:, :, 2] =\n $(zero(eltype(a))) $(zero(eltype(a))) ⋅ ⋅ \n $(zero(eltype(a))) $(zero(eltype(a))) ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n\n[:, :, 3] =\n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n\n[:, :, 4] =\n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ "
1180+
a′ = BlockSparseArray{elt,3,arrayt{elt,3}}(undef, [2, 2], [2, 2], [2, 2])
1181+
@allowscalar a′[1, 2, 1] = 121
1182+
for a in (a′, @view(a′[:, :, :]))
1183+
@test sprint(show, "text/plain", a) ==
1184+
"$(summary(a)):\n[:, :, 1] =\n $(zero(eltype(a))) $(eltype(a)(121)) ⋅ ⋅ \n $(zero(eltype(a))) $(zero(eltype(a))) ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n\n[:, :, 2] =\n $(zero(eltype(a))) $(zero(eltype(a))) ⋅ ⋅ \n $(zero(eltype(a))) $(zero(eltype(a))) ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n\n[:, :, 3] =\n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n\n[:, :, 4] =\n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ "
1185+
end
11611186
end
11621187
end
11631188
@testset "TypeParameterAccessors.position" begin

0 commit comments

Comments
 (0)