Skip to content

Commit f985875

Browse files
committed
Define simpler codepath for blockwise map
1 parent d360f75 commit f985875

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <support@itensor.org> and contributors"]
4-
version = "0.5.2"
4+
version = "0.5.3"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/blocksparsearrayinterface/map.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,21 @@
1+
using BlockArrays: BlockRange, blockisequal
12
using DerivableInterfaces: @interface, AbstractArrayInterface, interface
23
using GPUArraysCore: @allowscalar
34

5+
"""
    map_blockwise!(f, a_dest::AbstractArray, a_srcs::AbstractArray...)

Apply `f` one block at a time, writing into `a_dest` in place through `map!` on block
views of `a_dest` and of every array in `a_srcs`, and return `a_dest`.

If `f` maps the zeros of the source element types to the zero of the destination
element type, only the union of the stored block indices of `a_dest` and `a_srcs` is
visited; otherwise every block in `BlockRange(a_dest)` is visited.

All arrays are indexed with the same block index `I`, so they are expected to share
the same blocking (the visible call site checks `blockisequal` on each axis first).
"""
function map_blockwise!(f, a_dest::AbstractArray, a_srcs::AbstractArray...)
    # TODO: This assumes element types are numbers, generalize this logic.
    f_preserves_zeros = f(zero.(eltype.(a_srcs))...) == zero(eltype(a_dest))
    Is = if f_preserves_zeros
        # Merge the stored block indices of the destination and all sources. A bare
        # `(xs...)` is not valid Julia, so the splat must feed a call; `union` is the
        # ASCII spelling of `∪` and is the same function.
        union(map(eachblockstoredindex, (a_dest, a_srcs...))...)
    else
        # `f` can turn zeros into nonzeros, so every block has to be visited.
        BlockRange(a_dest)
    end
    for I in Is
        # `Base.Fix2(view, I)` is `a -> view(a, I)`: take block `I` of each source.
        map!(f, view(a_dest, I), map(Base.Fix2(view, I), a_srcs)...)
    end
    return a_dest
end
18+
419
# TODO: Rewrite this so that it takes the blocking structure
520
# made by combining the blocking of the axes (i.e. the blocking that
621
# is used to determine `union_stored_blocked_cartesianindices(...)`).
@@ -16,6 +31,16 @@ using GPUArraysCore: @allowscalar
1631
@interface interface map_zero_dim!(f, a_dest, a_srcs...)
1732
return a_dest
1833
end
34+
blockwise = all(
35+
ntuple(ndims(a_dest)) do dim
36+
ax = map(Base.Fix2(axes, dim), (a_dest, a_srcs...))
37+
return blockisequal(ax...)
38+
end,
39+
)
40+
if blockwise
41+
map_blockwise!(f, a_dest, a_srcs...)
42+
return a_dest
43+
end
1944
# TODO: This assumes element types are numbers, generalize this logic.
2045
f_preserves_zeros = f(zero.(eltype.(a_srcs))...) == zero(eltype(a_dest))
2146
a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs)

0 commit comments

Comments
 (0)