Skip to content

Commit bfd0a3e

Browse files
committed
introduce runtime representation of broadcast fusion
fix #21094 fix #22060 fix #22053 replaces #22063
1 parent fce0a3c commit bfd0a3e

File tree

6 files changed

+343
-129
lines changed

6 files changed

+343
-129
lines changed

base/broadcast.jl

+225
Original file line numberDiff line numberDiff line change
@@ -609,4 +609,229 @@ macro __dot__(x)
609609
esc(__dot__(x))
610610
end
611611

612+
############################################################
613+
## The parser turns dotted calls into the equivalent Fusion expression.
614+
## Effectively, this turns the Expr tree into a runtime AST,
615+
## for a limited subset of expression types.
616+
#
617+
## For example, in the expression:
618+
# d = sin.((a .+ (b .* c))...)
619+
## The kernel becomes
620+
# d' = Fusion{3}(
621+
# FusionApply(
622+
# sin,
623+
# ( FusionCall(
624+
# +,
625+
# ( FusionArg{1}(),
626+
# FusionCall(
627+
# *,
628+
# ( FusionArg{2}(),
629+
# FusionArg{3}() )), )), )),
630+
# (:a, :b, :c))
631+
## and then the final expansion becomes:
632+
# d = broadcast(d', a, b, c)
633+
634+
# Callable wrapper around a fused broadcast kernel `f`.
#
# Type parameters:
#   N      -- number of leading positional arguments the kernel expects
#   vararg -- a `Bool` flag carried in the type (asserted by the constructor)
#   T      -- concrete type of the wrapped kernel
struct Fusion{N, vararg#=::Bool=#, T}
    f::T
    # Debugging Metadata:
    # names::NTuple{N, Symbol}
    # source::LineNumberNode
    function Fusion{N, vararg}(f) where {N, vararg}
        # The typeassert rejects any non-`Bool` value for the `vararg` parameter.
        isvararg = vararg::Bool
        kernel_type = typeof(f)
        return new{N, isvararg, kernel_type}(f)
    end
end
643+
644+
# Field-less placeholder node; the type parameter `N` is the only payload.
struct FusionArg{N} end
646+
647+
# Node holding a single captured value `c`; the type parameter is always
# `typeof(c)` because the only constructor derives it from the argument.
struct FusionConstant{T}
    c::T
    FusionConstant(c) = new{typeof(c)}(c)
end
653+
654+
# Node pairing a callee `f` with a tuple of argument nodes `args`.
# Both type parameters are derived from the constructor arguments.
struct FusionCall{F, Args<:Tuple}
    f::F
    args::Args
    function FusionCall(f, args::Tuple)
        callee_type = typeof(f)
        args_type = typeof(args)
        return new{callee_type, args_type}(f, args)
    end
end
661+
662+
# Node pairing a callee `f` with an `N`-tuple of argument nodes `args`.
# `N` is exposed as a type parameter so methods can dispatch on the arity.
struct FusionApply{N, F, Args<:NTuple{N, Any}}
    f::F
    args::Args
    FusionApply(f, args::NTuple{N, Any}) where {N} =
        new{N, typeof(f), typeof(args)}(f, args)
end
669+
670+
# Flatten a vector of `(key, value)` tuples into one `Vector{Any}` laid out as
# `[key1, value1, key2, value2, ...]`.
#
# Every element of `kws` must be a 2-tuple; anything else fails the typeassert.
# Returns a freshly allocated vector of length `2 * length(kws)`.
function kw_to_vec(kws::Vector{Any})
    kwargs = Any[]
    sizehint!(kwargs, 2 * length(kws))
    # Visit *every* entry of `kws`: stepping the input index by two would skip
    # every second keyword and leave half of the output unassigned.
    for entry in kws
        kw = entry::Tuple{Any, Any}
        push!(kwargs, getfield(kw, 1), getfield(kw, 2))
    end
    return kwargs
end
679+
680+
# Node pairing a callee `f` with positional argument nodes `args` and keyword
# arguments. The keywords are stored as the vector returned by `kw_to_vec`.
struct FusionKWCall{F, Args<:Tuple}
    f::F
    args::Args
    kwargs::Vector{Any}
    function FusionKWCall(f, args::Tuple; kwargs...)
        flat_kwargs = kw_to_vec(kwargs)
        return new{typeof(f), typeof(args)}(f, args, flat_kwargs)
    end
end
688+
689+
# Node pairing a callee `f` with positional argument nodes `args` and keyword
# arguments. The keywords are stored as the vector returned by `kw_to_vec`.
struct FusionKWApply{F, Args<:Tuple}
    f::F
    args::Args
    kwargs::Vector{Any}
    function FusionKWApply(f, args::Tuple; kwargs...)
        flat_kwargs = kw_to_vec(kwargs)
        return new{typeof(f), typeof(args)}(f, args, flat_kwargs)
    end
end
697+
698+
# Return the leading elements of `t` as a tuple; the count is the value
# carried by the `Val` argument `N`.
function tuplehead(t::Tuple, N::Val)
    return ntuple(N) do i
        t[i]
    end
end
701+
# Return the elements of `t` after the first `N` as a tuple (empty when
# `N >= M`). The tuple expression is built at generation time with one
# `getfield` per remaining element.
@generated function tupletail(t::NTuple{M, Any}, ::Val{N}) where {N, M}
    # Non-generated alternatives, to be enabled once inference improves:
    #tupletail(t, Nreq) = ntuple(i -> t[i + Nreq], length(t) - Nreq)
    #tupletail(t, Nreq) = t[(Nreq + 1):end]
    tail = Expr(:tuple)
    for i in (N + 1):M
        push!(tail.args, :(getfield(t, $i)))
    end
    return tail
end
711+
712+
# Fixed-arity kernel: forward exactly `N` arguments to the wrapped callee.
@inline function (f::Fusion{N, false})(args::Vararg{Any, N}) where {N}
    return f.f(args...)
end
713+
# Variadic kernel: requires at least `Nreq` arguments. The first `Nreq` are
# passed individually; everything after them is passed as one trailing tuple.
function (f::Fusion{Nreq, true})(args::Vararg{Any, M}) where {Nreq, M}
    if M < Nreq
        throw(MethodError(f, args))
    end
    required = tuplehead(args, Val(Nreq))
    extras = tupletail(args, Val(Nreq))
    return f.f(required..., extras)
end
719+
# Evaluate an argument placeholder: select the `N`-th of the supplied arguments.
@inline function (::FusionArg{N})(args...) where {N}
    return args[N]
end
720+
# Evaluate a constant node: the supplied arguments are ignored.
@inline function (f::FusionConstant)(args...)
    return f.c
end
721+
# Evaluate a call node: evaluate each argument node against `args`, then call
# the stored callee with the results.
@inline function (f::FusionCall)(args...)
    evaluated = map(node -> node(args...), f.args)
    return f.f(evaluated...)
end
722+
# TODO: calling _apply on map _apply is not handled by inference
723+
# for now, we unroll some cases and generate others, to help it out
724+
#@inline (f::FusionApply)(args...) = Core._apply(f.f, map(a -> a(args...), f.args)...)
725+
# Hand-unrolled evaluation for 0 to 3 argument nodes. Each node is evaluated
# against `args` and its result is splatted into the call to the stored callee.
@inline function (f::FusionApply{0})(args...)
    return f.f()
end
@inline function (f::FusionApply{1})(args...)
    part1 = f.args[1](args...)
    return f.f(part1...)
end
@inline function (f::FusionApply{2})(args...)
    part1 = f.args[1](args...)
    part2 = f.args[2](args...)
    return f.f(part1..., part2...)
end
@inline function (f::FusionApply{3})(args...)
    part1 = f.args[1](args...)
    part2 = f.args[2](args...)
    part3 = f.args[3](args...)
    return f.f(part1..., part2..., part3...)
end
729+
# General-arity evaluation: generate a single `Core._apply(f.f, ...)` call with
# one evaluated argument node per slot of `f.args`.
@generated function (f::FusionApply{N})(args...) where {N}
    call = Expr(:call, GlobalRef(Core, :_apply), :(f.f))
    # Explicit loop rather than a generator: the body of a generated function
    # must not create closures.
    for i in 1:N
        push!(call.args, :(getfield(f.args, $i)(args...)))
    end
    return call
end
733+
# Evaluate a keyword call node, i.e. the equivalent of `f.f(args...; kwargs...)`.
@inline function (f::FusionKWCall)(args...)
    evaluated = map(node -> node(args...), f.args)
    # With no stored keywords this is an ordinary positional call.
    isempty(f.kwargs) && return f.f(evaluated...)
    # Otherwise go through the callee's keyword-sorter function, passing the
    # stored keyword vector and the callee itself ahead of the positionals.
    kwsorter = Core.kwfunc(f.f)
    return kwsorter(f.kwargs, f.f, evaluated...)
end
742+
# Evaluate a keyword apply node, i.e. the equivalent of
# `Core._apply(f.f, args...; kwargs...)`.
@inline function (f::FusionKWApply)(args...)
    evaluated = map(node -> node(args...), f.args)
    # With no stored keywords this is an ordinary `_apply` of the callee.
    isempty(f.kwargs) && return Core._apply(f.f, evaluated...)
    # Otherwise apply the callee's keyword-sorter function; the keyword vector
    # and the callee are each wrapped in a 1-tuple so they are passed whole.
    kwsorter = Core.kwfunc(f.f)
    return Core._apply(kwsorter, (f.kwargs,), (f.f,), evaluated...)
end
751+
752+
# Print a `Fusion` as an anonymous-function-like expression:
# `(a_1, a_2, ...) -> <body>`, using generated placeholder names.
function Base.show(io::IO, f::Fusion{N, vararg}) where {N, vararg}
    # A variadic kernel gets one extra name for its trailing argument.
    nargs = N
    if vararg
        nargs += 1
    end
    names = String[ string("a_", i) for i in 1:nargs ] # f.names
    print(io, "(")
    join(io, names, ", ")
    if vararg
        print(io, "...")
    end
    print(io, ") -> ")
    return show_fusion(io, f.f, names)
end
761+
762+
# Print an argument placeholder as the `N`-th entry of `names`.
function show_fusion(io::IO, ::FusionArg{N}, names) where N
    argname = names[N]
    print(io, argname)
    return nothing
end
766+
767+
# Print a constant node by printing its stored value; `names` is unused.
function show_fusion(io::IO, f::FusionConstant, names)
    value = f.c
    print(io, value)
    return nothing
end
771+
772+
# Print a call node as `callee(arg1, arg2, ...)`, rendering each argument node
# recursively.
function show_fusion(io::IO, f::FusionCall, names)
    Base.show(io, f.f)
    print(io, '(')
    for (k, arg) in enumerate(f.args)
        # Separator goes before every argument except the first.
        k == 1 || print(io, ", ")
        show_fusion(io, arg, names)
    end
    print(io, ')')
    return nothing
end
784+
785+
# Print an apply node as `Core._apply(callee, arg1, arg2, ...)`, rendering each
# argument node recursively.
function show_fusion(io::IO, f::FusionApply, names)
    print(io, "Core._apply(")
    Base.show(io, f.f)
    foreach(f.args) do arg
        print(io, ", ")
        show_fusion(io, arg, names)
    end
    print(io, ')')
    return nothing
end
795+
796+
# Print a keyword call node as `callee(arg1, arg2, ...; key1=value1, ...)`.
#
# Positional argument nodes are rendered recursively. The keyword section walks
# `f.kwargs` two entries at a time, treating each odd-index entry as a key and
# the entry after it as that key's value.
function show_fusion(io::IO, f::FusionKWCall, names)
    Base.show(io, f.f)
    print(io, '(')
    for (k, arg) in enumerate(f.args)
        k == 1 || print(io, ", ")
        show_fusion(io, arg, names)
    end
    print(io, "; ")
    nkw = length(f.kwargs)
    for i in 1:2:nkw
        i == 1 || print(io, ", ")
        print(io, f.kwargs[i])
        print(io, "=")
        # Print the value too; without it the output reads `key=` with nothing
        # after the equals sign. Skipped only if a key has no following entry.
        i < nkw && Base.show(io, f.kwargs[i + 1])
    end
    print(io, ')')
    return nothing
end
816+
817+
818+
# Print a keyword apply node as `Core._apply(callee, arg1, ...; #=kwargs=#...)`.
# The keyword arguments themselves are elided from the output.
function show_fusion(io::IO, f::FusionKWApply, names)
    print(io, "Core._apply(")
    Base.show(io, f.f)
    foreach(f.args) do arg
        print(io, ", ")
        show_fusion(io, arg, names)
    end
    print(io, "; #=kwargs=#...)")
    return nothing
end
828+
829+
830+
# Fallback for any node type without a dedicated method: print the value
# wrapped in a block comment marking it as unexpected. `names` is unused.
function show_fusion(io::IO, @nospecialize(f), names)
    opening = "#= unexpected expression "
    closing = " =#"
    print(io, opening)
    show(io, f)
    print(io, closing)
    return nothing
end
836+
612837
end # module

base/inference.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import Core: _apply, svec, apply_type, Builtin, IntrinsicFunction, MethodInstanc
44

55
#### parameters limiting potentially-infinite types ####
66
const MAX_TYPEUNION_LEN = 3
7-
const MAX_TYPE_DEPTH = 8
7+
const MAX_TYPE_DEPTH = 10
88
const TUPLE_COMPLEXITY_LIMIT_DEPTH = 3
99

1010
const MAX_INLINE_CONST_SIZE = 256

0 commit comments

Comments
 (0)