1
+ using BlockArrays: BlockRange, blockisequal
1
2
using DerivableInterfaces: @interface , AbstractArrayInterface, interface
2
3
using GPUArraysCore: @allowscalar
3
4
5
"""
    map_blockwise!(f, a_dest::AbstractArray, a_srcs::AbstractArray...)

Apply `f` to `a_srcs` one block at a time, writing each result into the matching
block of `a_dest`, and return `a_dest`.

Each visited block `b` is processed as `map!(f, view(a_dest, b), views_of_srcs...)`,
where every source is viewed with the same `b`. The set of visited blocks depends on
whether `f` maps zeros to zero:

- if `f(zeros...) == zero(eltype(a_dest))`, only the union of
  `eachblockstoredindex` over `a_dest` and all of `a_srcs` is visited;
- otherwise every block in `BlockRange(a_dest)` is visited.

NOTE(review): indexing every array with the same block index presumably requires that
all arguments share the same blocking; this function does not check that itself, so
confirm against the caller.
"""
function map_blockwise!(f, a_dest::AbstractArray, a_srcs::AbstractArray...)
    # TODO: This assumes element types are numbers, generalize this logic.
    src_zeros = map(zero ∘ eltype, a_srcs)
    preserves_zeros = f(src_zeros...) == zero(eltype(a_dest))
    # When zeros are preserved, blocks outside this union cannot change the result
    # and are skipped; otherwise all blocks of the destination must be written.
    blocks = preserves_zeros ?
        union(map(eachblockstoredindex, (a_dest, a_srcs...))...) :
        BlockRange(a_dest)
    for b in blocks
        dest_block = view(a_dest, b)
        src_blocks = map(a_src -> view(a_src, b), a_srcs)
        map!(f, dest_block, src_blocks...)
    end
    return a_dest
end
18
+
4
19
# TODO : Rewrite this so that it takes the blocking structure
5
20
# made by combining the blocking of the axes (i.e. the blocking that
6
21
# is used to determine `union_stored_blocked_cartesianindices(...)`).
@@ -16,6 +31,16 @@ using GPUArraysCore: @allowscalar
16
31
@interface interface map_zero_dim! (f, a_dest, a_srcs... )
17
32
return a_dest
18
33
end
34
+ blockwise = all (
35
+ ntuple (ndims (a_dest)) do dim
36
+ ax = map (Base. Fix2 (axes, dim), (a_dest, a_srcs... ))
37
+ return blockisequal (ax... )
38
+ end ,
39
+ )
40
+ if blockwise
41
+ map_blockwise! (f, a_dest, a_srcs... )
42
+ return a_dest
43
+ end
19
44
# TODO : This assumes element types are numbers, generalize this logic.
20
45
f_preserves_zeros = f (zero .(eltype .(a_srcs))... ) == zero (eltype (a_dest))
21
46
a_dest, a_srcs = reblock (a_dest), reblock .(a_srcs)
0 commit comments