Skip to content

Commit

Permalink
Fix sum(bc::Broadcasted; dims = 1, init = 0) (#43736)
Browse files Browse the repository at this point in the history
This PR make `has_fast_linear_indexing` rely on `IndexStyle`/`ndims` to
fix `mapreduce` for `Broadcasted` with `dim > 1`.
Before:
```julia
julia> a = randn(100,100);

julia> bc = Broadcast.instantiate(Base.broadcasted(+,a,a));

julia> sum(bc,dims = 1,init = 0.0) == sum(collect(bc), dims = 1)
ERROR: MethodError: no method matching LinearIndices(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(+), Tuple{Matrix{Float64}, Matrix{Float64}}})
```
After:
```julia
julia> sum(bc,dims = 1,init = 0.0) == sum(collect(bc), dims = 1)
true
```

This should extend the optimized fallback to more `AbstractArray`. (e.g.
`SubArray`)

Test added.
  • Loading branch information
N5N3 authored Jul 29, 2024
1 parent 20b3af3 commit f979ee9
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
7 changes: 2 additions & 5 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,8 @@ end

## generic (map)reduction

has_fast_linear_indexing(a::AbstractArrayOrBroadcasted) = false
has_fast_linear_indexing(a::Array) = true
has_fast_linear_indexing(::Union{Number,Ref,AbstractChar}) = true # 0d objects, for Broadcasted
has_fast_linear_indexing(bc::Broadcast.Broadcasted) =
all(has_fast_linear_indexing, bc.args)
has_fast_linear_indexing(a::AbstractArrayOrBroadcasted) = IndexStyle(a) === IndexLinear()
has_fast_linear_indexing(a::AbstractVector) = true

function check_reducedims(R, A)
# Check whether R has compatible dimensions w.r.t. A for reduction
Expand Down
4 changes: 4 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,10 @@ end
@test sum(bc, dims=1, init=0) == [5]
bc = Broadcast.instantiate(Broadcast.broadcasted(*, ['a','b'], 'c'))
@test prod(bc, dims=1, init="") == ["acbc"]

a = rand(-10:10,32,4); b = rand(-10:10,32,4)
bc = Broadcast.instantiate(Broadcast.broadcasted(+,a,b))
@test sum(bc; dims = 1, init = 0.0) == sum(collect(bc); dims = 1, init = 0.0)
end

# treat Pair as scalar:
Expand Down

0 comments on commit f979ee9

Please sign in to comment.