Skip to content

Commit 6c05eb9

Browse files
authored
Define simpler codepath for blockwise map (#118)
1 parent d360f75 commit 6c05eb9

File tree

4 files changed

+60
-10
lines changed

4 files changed

+60
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.5.2"
4+
version = "0.5.3"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -290,13 +290,6 @@ function blockrange(axis::AbstractUnitRange, r::Int)
290290
return error("Slicing with integer values isn't supported.")
291291
end
292292

293-
function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Block{1}})
294-
for b in r
295-
@assert b blockaxes(axis, 1)
296-
end
297-
return r
298-
end
299-
300293
# This handles changing the blocking, for example:
301294
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
302295
# I = blockedrange([4, 4])
@@ -315,13 +308,20 @@ end
315308
# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
316309
# I = BlockVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
317310
# a[I, I]
318-
function blockrange(axis::BlockedOneTo{<:Integer}, r::AbstractBlockVector{<:Block{1}})
311+
function blockrange(axis::AbstractUnitRange, r::AbstractBlockVector{<:Block{1}})
319312
for b in r
320313
@assert b blockaxes(axis, 1)
321314
end
322315
return only(blockaxes(r))
323316
end
324317

318+
function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Block{1}})
319+
for b in r
320+
@assert b blockaxes(axis, 1)
321+
end
322+
return r
323+
end
324+
325325
using BlockArrays: BlockSlice
326326
function blockrange(axis::AbstractUnitRange, r::BlockSlice)
327327
return blockrange(axis, r.block)

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ end
445445

446446
to_blocks_indices(I::BlockSlice{<:BlockRange{1}}) = Int.(I.block)
447447
to_blocks_indices(I::BlockIndices{<:Vector{<:Block{1}}}) = Int.(I.blocks)
448-
to_blocks_indices(I::Base.Slice{<:BlockedOneTo}) = Base.OneTo(blocklength(I.indices))
448+
to_blocks_indices(I::Base.Slice) = Base.OneTo(blocklength(I.indices))
449449

450450
@interface ::AbstractBlockSparseArrayInterface function BlockArrays.blocks(
451451
a::SubArray{<:Any,<:Any,<:Any,<:Tuple{Vararg{BlockSliceCollection}}}

src/blocksparsearrayinterface/map.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,52 @@
1+
using BlockArrays: BlockRange, blockisequal
12
using DerivableInterfaces: @interface, AbstractArrayInterface, interface
23
using GPUArraysCore: @allowscalar
34

5+
# Check if the block structures are the same.
6+
function same_block_structure(as::AbstractArray...)
7+
isempty(as) && return true
8+
return all(
9+
ntuple(ndims(first(as))) do dim
10+
ax = map(Base.Fix2(axes, dim), as)
11+
return blockisequal(ax...)
12+
end,
13+
)
14+
end
15+
16+
# Find the common stored blocks, assuming the block structures are the same.
17+
function union_eachblockstoredindex(as::AbstractArray...)
18+
return (map(eachblockstoredindex, as)...)
19+
end
20+
21+
function map_blockwise!(f, a_dest::AbstractArray, a_srcs::AbstractArray...)
22+
# TODO: This assumes element types are numbers, generalize this logic.
23+
f_preserves_zeros = f(zero.(eltype.(a_srcs))...) == zero(eltype(a_dest))
24+
Is = if f_preserves_zeros
25+
union_eachblockstoredindex(a_dest, a_srcs...)
26+
else
27+
BlockRange(a_dest)
28+
end
29+
for I in Is
30+
# TODO: Use:
31+
# block_dest = @view a_dest[I]
32+
# or:
33+
# block_dest = @view! a_dest[I]
34+
block_dest = blocks_maybe_single(a_dest)[Int.(Tuple(I))...]
35+
# TODO: Use:
36+
# block_srcs = map(a_src -> @view(a_src[I]), a_srcs)
37+
block_srcs = map(a_srcs) do a_src
38+
return blocks_maybe_single(a_src)[Int.(Tuple(I))...]
39+
end
40+
# TODO: Use `map!!` to handle immutable blocks.
41+
map!(f, block_dest, block_srcs...)
42+
# Replace the entire block, handles initializing new blocks
43+
# or if blocks are immutable.
44+
# TODO: Use `a_dest[I] = block_dest`.
45+
blocks(a_dest)[Int.(Tuple(I))...] = block_dest
46+
end
47+
return a_dest
48+
end
49+
450
# TODO: Rewrite this so that it takes the blocking structure
551
# made by combining the blocking of the axes (i.e. the blocking that
652
# is used to determine `union_stored_blocked_cartesianindices(...)`).
@@ -16,6 +62,10 @@ using GPUArraysCore: @allowscalar
1662
@interface interface map_zero_dim!(f, a_dest, a_srcs...)
1763
return a_dest
1864
end
65+
if same_block_structure(a_dest, a_srcs...)
66+
map_blockwise!(f, a_dest, a_srcs...)
67+
return a_dest
68+
end
1969
# TODO: This assumes element types are numbers, generalize this logic.
2070
f_preserves_zeros = f(zero.(eltype.(a_srcs))...) == zero(eltype(a_dest))
2171
a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs)

0 commit comments

Comments
 (0)