Skip to content

Commit ae6cdd1

Browse files
committed
Fix JuliaLang#21291, type-stabilize broadcast over tuples and scalars
1 parent d9771af commit ae6cdd1

File tree

2 files changed

+17
-14
lines changed

2 files changed

+17
-14
lines changed

base/broadcast.jl

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -311,20 +311,17 @@ end
311311
end
312312
end
313313
@inline broadcast_c(f, ::Type{Any}, a...) = f(a...)
314-
broadcast_c(f, ::Type{Tuple}, As...) =
315-
ntuple(k -> f(_tuplebroadcast_getargs(As, k)...), _tuplebroadcast_reslength(As))
316-
broadcast_c{T}(f, ::Type{Tuple}, ::Type{T}, As...) =
317-
ntuple(k -> f(T, _tuplebroadcast_getargs(As, k)...), _tuplebroadcast_reslength(As))
318-
@inline _tuplebroadcast_getargs(::Tuple{}, k) = ()
319-
@inline _tuplebroadcast_getargs(As, k) =
320-
(_broadcast_getindex(first(As), k), _tuplebroadcast_getargs(tail(As), k)...)
321-
@noinline _tuplebroadcast_reslength(As) =
322-
_tuplebroadcast_maxlength(_tuplebroadcast_length(first(As)), tail(As))
323-
@inline _tuplebroadcast_maxlength(l, As) =
324-
_tuplebroadcast_maxlength(max(l, _tuplebroadcast_length(first(As))), tail(As))
325-
@inline _tuplebroadcast_maxlength(l, ::Tuple{}) = l
326-
@inline _tuplebroadcast_length(t::Tuple) = length(t)
327-
@inline _tuplebroadcast_length(s) = 1
314+
@inline broadcast_c(f, ::Type{Tuple}, A, Bs...) =
315+
tuplebroadcast(f, first_tuple(A, Bs...), A, Bs...)
316+
@inline tuplebroadcast{N}(f, ::NTuple{N,Any}, As...) =
317+
ntuple(k -> f(tuplebroadcast_getargs(As, k)...), Val{N})
318+
@inline tuplebroadcast{N,T}(f, ::NTuple{N,Any}, ::Type{T}, As...) =
319+
ntuple(k -> f(T, tuplebroadcast_getargs(As, k)...), Val{N})
320+
first_tuple(A::Tuple, Bs...) = A
321+
@inline first_tuple(A, Bs...) = first_tuple(Bs...)
322+
tuplebroadcast_getargs(::Tuple{}, k) = ()
323+
@inline tuplebroadcast_getargs(As, k) =
324+
(_broadcast_getindex(first(As), k), tuplebroadcast_getargs(tail(As), k)...)
328325

329326
"""
330327
broadcast(f, As...)

test/broadcast.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,3 +494,9 @@ end
494494
A[1:3,1:3] .= [ones(2,2)]
495495
@test all(A[1:3,1:3] .== [ones(2,2)])
496496
end
497+
498+
# Issue #21291
499+
let t = (0, 1, 2)
500+
o = 1
501+
@test @inferred(broadcast(+, t, o)) == (1, 2, 3)
502+
end

0 commit comments

Comments
 (0)