Skip to content

Commit d360f75

Browse files
authored
Noncontiguous slicing (#116)
1 parent 9eb742b commit d360f75

File tree

7 files changed

+143
-15
lines changed

7 files changed

+143
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.5.1"
4+
version = "0.5.2"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,14 @@ using SparseArraysBase:
3030

3131
# A return type for `blocks(array)` when `array` isn't blocked.
3232
# Represents a vector with just that single block.
33-
struct SingleBlockView{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N}
33+
struct SingleBlockView{N,Array<:AbstractArray{<:Any,N}} <: AbstractArray{Array,N}
3434
array::Array
3535
end
3636
Base.parent(a::SingleBlockView) = a.array
37+
Base.size(a::SingleBlockView) = ntuple(Returns(1), ndims(a))
3738
blocks_maybe_single(a) = blocks(a)
3839
blocks_maybe_single(a::Array) = SingleBlockView(a)
39-
function Base.getindex(a::SingleBlockView{<:Any,N}, index::Vararg{Int,N}) where {N}
40+
function Base.getindex(a::SingleBlockView{N}, index::Vararg{Int,N}) where {N}
4041
@assert all(isone, index)
4142
return parent(a)
4243
end
@@ -357,7 +358,11 @@ function blockrange(axis::AbstractUnitRange, r::Base.Slice)
357358
end
358359

359360
function blockrange(axis::AbstractUnitRange, r::NonBlockedVector)
360-
return Block(1):Block(1)
361+
return Block.(Base.OneTo(1))
362+
end
363+
364+
function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Integer})
365+
return Block.(Base.OneTo(1))
361366
end
362367

363368
function blockrange(axis::AbstractUnitRange, r)

src/BlockSparseArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ include("abstractblocksparsearray/abstractblocksparsearray.jl")
2727
include("abstractblocksparsearray/abstractblocksparsematrix.jl")
2828
include("abstractblocksparsearray/abstractblocksparsevector.jl")
2929
include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl")
30+
include("abstractblocksparsearray/unblockedsubarray.jl")
3031
include("abstractblocksparsearray/views.jl")
3132
include("abstractblocksparsearray/arraylayouts.jl")
3233
include("abstractblocksparsearray/sparsearrayinterface.jl")

src/abstractblocksparsearray/arraylayouts.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,20 @@ function ArrayLayouts.sub_materialize(layout::BlockLayout{<:SparseLayout}, a, ax
4343
return a_dest
4444
end
4545

46+
function _similar(arraytype::Type{<:AbstractArray}, size::Tuple)
47+
return similar(arraytype, size)
48+
end
49+
function _similar(
50+
::Type{<:SubArray{<:Any,<:Any,<:ArrayType}}, size::Tuple
51+
) where {ArrayType}
52+
return similar(ArrayType, size)
53+
end
54+
4655
# Materialize a SubArray view.
4756
function ArrayLayouts.sub_materialize(
4857
layout::BlockLayout{<:SparseLayout}, a, axes::Tuple{Vararg{Base.OneTo}}
4958
)
50-
a_dest = blocktype(a)(undef, length.(axes))
59+
a_dest = _similar(blocktype(a), length.(axes))
5160
a_dest .= a
5261
return a_dest
5362
end
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
using ArrayLayouts: ArrayLayouts, MemoryLayout
2+
using Base.Broadcast: Broadcast, BroadcastStyle
3+
using BlockArrays: BlockArrays, Block, BlockIndexRange, BlockSlice
4+
using TypeParameterAccessors: TypeParameterAccessors, parenttype, similartype
5+
6+
const UnblockedIndices = Union{
7+
Vector{<:Integer},BlockSlice{<:Block{1}},BlockSlice{<:BlockIndexRange{1}}
8+
}
9+
10+
const UnblockedSubArray{T,N} = SubArray{
11+
T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{UnblockedIndices}}
12+
}
13+
14+
function BlockArrays.blocks(a::UnblockedSubArray)
15+
return SingleBlockView(a)
16+
end
17+
18+
function DerivableInterfaces.interface(arraytype::Type{<:UnblockedSubArray})
19+
return interface(blocktype(parenttype(arraytype)))
20+
end
21+
22+
function ArrayLayouts.MemoryLayout(arraytype::Type{<:UnblockedSubArray})
23+
return MemoryLayout(blocktype(parenttype(arraytype)))
24+
end
25+
26+
function Broadcast.BroadcastStyle(arraytype::Type{<:UnblockedSubArray})
27+
return BroadcastStyle(blocktype(parenttype(arraytype)))
28+
end
29+
30+
function TypeParameterAccessors.similartype(arraytype::Type{<:UnblockedSubArray}, elt::Type)
31+
return similartype(blocktype(parenttype(arraytype)), elt)
32+
end
33+
34+
function Base.similar(
35+
a::UnblockedSubArray, elt::Type, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}}
36+
)
37+
return similar(similartype(blocktype(parenttype(a)), elt), axes)
38+
end
39+
function Base.similar(a::UnblockedSubArray, elt::Type, size::Tuple{Int,Vararg{Int}})
40+
return similar(a, elt, Base.OneTo.(size))
41+
end
42+
43+
function ArrayLayouts.sub_materialize(a::UnblockedSubArray)
44+
a_cpu = adapt(Array, a)
45+
a_cpu′ = similar(a_cpu)
46+
a_cpu′ .= a_cpu
47+
if typeof(a) === typeof(a_cpu)
48+
return a_cpu′
49+
end
50+
a′ = similar(a)
51+
a′ .= a_cpu′
52+
return a′
53+
end
54+
55+
function Base.map!(
56+
f, a_dest::AbstractArray, a_src1::UnblockedSubArray, a_src_rest::UnblockedSubArray...
57+
)
58+
return invoke(
59+
map!,
60+
Tuple{Any,AbstractArray,AbstractArray,Vararg{AbstractArray}},
61+
f,
62+
a_dest,
63+
a_src1,
64+
a_src_rest...,
65+
)
66+
end
67+
68+
# Fix ambiguity and scalar indexing errors with GPUArrays.
69+
using Adapt: adapt
70+
using GPUArraysCore: GPUArraysCore
71+
function Base.map!(
72+
f,
73+
a_dest::GPUArraysCore.AnyGPUArray,
74+
a_src1::UnblockedSubArray,
75+
a_src_rest::UnblockedSubArray...,
76+
)
77+
a_dest_cpu = adapt(Array, a_dest)
78+
a_srcs_cpu = map(adapt(Array), (a_src1, a_src_rest...))
79+
map!(f, a_dest_cpu, a_srcs_cpu...)
80+
a_dest .= a_dest_cpu
81+
return a_dest
82+
end
83+
84+
function Base.iszero(a::UnblockedSubArray)
85+
return invoke(iszero, Tuple{AbstractArray}, adapt(Array, a))
86+
end

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,19 @@ end
364364
function Base.size(a::SparseSubArrayBlocks)
365365
return length.(axes(a))
366366
end
367-
# TODO: Define `isstored`.
367+
368+
# TODO: Make a faster version for when the slice is blockwise.
369+
function SparseArraysBase.isstored(
370+
a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}
371+
) where {N}
372+
J = Base.reindex(parentindices(a.array), to_indices(a.array, Block.(I)))
373+
# TODO: Try doing this blockwise when possible rather
374+
# than elementwise.
375+
return any(Iterators.product(J...)) do K
376+
return isstored(parent(a.array), K...)
377+
end
378+
end
379+
368380
# TODO: Define `getstoredindex`, `getunstoredindex` instead.
369381
function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where {N}
370382
# TODO: Should this be defined as `@view a.array[Block(I)]` instead?
@@ -400,9 +412,17 @@ function Base.isassigned(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) whe
400412
# TODO: Implement this properly.
401413
return true
402414
end
403-
function SparseArraysBase.eachstoredindex(a::SparseSubArrayBlocks)
404-
return eachstoredindex(view(blocks(parent(a.array)), blockrange(a)...))
415+
416+
function SparseArraysBase.eachstoredindex(::IndexCartesian, a::SparseSubArrayBlocks)
417+
return filter(eachindex(a)) do I
418+
return isstored(a, I)
419+
end
420+
421+
## # TODO: This only works for blockwise slices, i.e. slices using
422+
## # `BlockSliceCollection`.
423+
## return eachstoredindex(view(blocks(parent(a.array)), blockrange(a)...))
405424
end
425+
406426
# TODO: Either make this the generic interface or define
407427
# `SparseArraysBase.sparse_storage`, which is used
408428
# to defined this.

test/test_basics.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,6 @@ arrayts = (Array, JLArray)
5151
a[Block(2, 2)] = dev(randn(elt, 3, 3))
5252
@test_broken a[:, 4]
5353

54-
# TODO: Fix this and turn it into a proper test.
55-
a = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
56-
a[Block(1, 1)] = dev(randn(elt, 2, 2))
57-
a[Block(2, 2)] = dev(randn(elt, 3, 3))
58-
@test_broken a[:, [2, 4]]
59-
@test_broken a[[3, 5], [2, 4]]
60-
6154
# TODO: Fix this and turn it into a proper test.
6255
a = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
6356
a[Block(1, 1)] = dev(randn(elt, 2, 2))
@@ -713,6 +706,20 @@ arrayts = (Array, JLArray)
713706
@test a[Block(2, 2)[1:2, 2:3]] == b
714707
@test blockstoredlength(a) == 1
715708

709+
# Noncontiguous slicing.
710+
a = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
711+
a[Block(1, 1)] = dev(randn(elt, 2, 2))
712+
a[Block(2, 2)] = dev(randn(elt, 3, 3))
713+
I = ([3, 5], [2, 4])
714+
@test Array(a[I...]) == Array(a)[I...]
715+
716+
# Noncontiguous slicing.
717+
a = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
718+
a[Block(1, 1)] = dev(randn(elt, 2, 2))
719+
a[Block(2, 2)] = dev(randn(elt, 3, 3))
720+
I = (:, [2, 4])
721+
@test Array(a[I...]) == Array(a)[I...]
722+
716723
a = BlockSparseArray{elt}(undef, [2, 3], [2, 3])
717724
@views for b in [Block(1, 1), Block(2, 2)]
718725
a[b] = randn(elt, size(a[b]))

0 commit comments

Comments
 (0)