@@ -609,4 +609,229 @@ macro __dot__(x)
609
609
esc (__dot__ (x))
610
610
end
611
611
612
+ # ###########################################################
613
+ # # The parser turns dotted calls into the equivalent Fusion expression.
614
+ # # Effectively, this turns the Expr tree into a runtime AST,
615
+ # # for a limited subset of expression types.
616
+ #
617
+ # # For example, in the expression:
618
+ # d = sin.((a .+ (b .* c))...)
619
+ # # The kernel becomes
620
+ # d' = Fusion{3}(
621
+ # FusionApply(
622
+ # sin,
623
+ # ( FusionCall(
624
+ # +,
625
+ # ( FusionArg{1}(),
626
+ # FusionCall(
627
+ # *,
628
+ # ( FusionArg{2}(),
629
+ # FusionArg{3}() )), )), )),
630
+ # (:a, :b, :c))
631
+ # # and then the final expansion becomes:
632
+ # d = broadcast(d', a, b, c)
633
+
634
+ struct Fusion{N, vararg#= ::Bool=# , T}
635
+ f:: T
636
+ # Debugging Metadata:
637
+ # names::NTuple{N, Symbol}
638
+ # source::LineNumberNode
639
+ function Fusion {N, vararg} (f) where {N, vararg}
640
+ return new {N, vararg::Bool, typeof(f)} (f)
641
+ end
642
+ end
643
+
644
+ struct FusionArg{N}
645
+ end
646
+
647
+ struct FusionConstant{T}
648
+ c:: T
649
+ function FusionConstant (c) where {}
650
+ return new {typeof(c)} (c)
651
+ end
652
+ end
653
+
654
+ struct FusionCall{F, Args<: Tuple }
655
+ f:: F
656
+ args:: Args
657
+ function FusionCall (f, args:: Tuple ) where {}
658
+ return new {typeof(f), typeof(args)} (f, args)
659
+ end
660
+ end
661
+
662
+ struct FusionApply{N, F, Args<: NTuple{N, Any} }
663
+ f:: F
664
+ args:: Args
665
+ function FusionApply (f, args:: NTuple{N, Any} ) where {N}
666
+ return new {N, typeof(f), typeof(args)} (f, args)
667
+ end
668
+ end
669
+
670
+ function kw_to_vec (kws:: Vector{Any} )
671
+ kwargs = Vector {Any} (2 * length (kws))
672
+ for i in 1 : 2 : length (kws)
673
+ kw = kws[i]:: Tuple{Any, Any}
674
+ kwargs[i] = getfield (kw, 1 )
675
+ kwargs[i + 1 ] = getfield (kw, 2 )
676
+ end
677
+ return kwargs
678
+ end
679
+
680
+ struct FusionKWCall{F, Args<: Tuple }
681
+ f:: F
682
+ args:: Args
683
+ kwargs:: Vector{Any}
684
+ function FusionKWCall (f, args:: Tuple ; kwargs... ) where {}
685
+ return new {typeof(f), typeof(args)} (f, args, kw_to_vec (kwargs))
686
+ end
687
+ end
688
+
689
+ struct FusionKWApply{F, Args<: Tuple }
690
+ f:: F
691
+ args:: Args
692
+ kwargs:: Vector{Any}
693
+ function FusionKWApply (f, args:: Tuple ; kwargs... ) where {}
694
+ return new {typeof(f), typeof(args)} (f, args, kw_to_vec (kwargs))
695
+ end
696
+ end
697
+
698
+ function tuplehead (t:: Tuple , N:: Val )
699
+ return ntuple (i -> t[i], N)
700
+ end
701
+ @generated function tupletail (t:: NTuple{M, Any} , :: Val{N} ) where {N, M}
702
+ # alternative, non-generated versions,
703
+ # enable when inference is improved:
704
+ # tupletail(t, Nreq) = ntuple(i -> t[i + Nreq], length(t) - Nreq)
705
+ # tupletail(t, Nreq) = t[(Nreq + 1):end]
706
+ args = Any[ :(getfield (t, $ i)) for i in (N + 1 ): M ]
707
+ tpl = Expr (:tuple )
708
+ tpl. args = args
709
+ return tpl
710
+ end
711
+
712
+ @inline (f:: Fusion{N, false} )(args:: Vararg{Any, N} ) where {N} = f. f (args... )
713
+ function (f:: Fusion{Nreq, true} )(args:: Vararg{Any, M} ) where {Nreq, M}
714
+ M >= Nreq || throw (MethodError (f, args))
715
+ fargs = tuplehead (args, Val (Nreq))
716
+ vararg = tupletail (args, Val (Nreq))
717
+ return f. f ((fargs... , vararg). .. )
718
+ end
719
+ @inline (f:: FusionArg{N} )(args... ) where {N} = args[N]
720
+ @inline (f:: FusionConstant )(args... ) = f. c
721
+ @inline (f:: FusionCall )(args... ) = f. f (map (a -> a (args... ), f. args)... )
722
+ # TODO : calling _apply on map _apply is not handled by inference
723
+ # for now, we unroll some cases and generate others, to help it out
724
+ # @inline (f::FusionApply)(args...) = Core._apply(f.f, map(a -> a(args...), f.args)...)
725
+ @inline (f:: FusionApply{0} )(args... ) = f. f ()
726
+ @inline (f:: FusionApply{1} )(args... ) = f. f (f. args[1 ](args... ). .. )
727
+ @inline (f:: FusionApply{2} )(args... ) = f. f (f. args[1 ](args... ). .. , f. args[2 ](args... ). .. )
728
+ @inline (f:: FusionApply{3} )(args... ) = f. f (f. args[1 ](args... ). .. , f. args[2 ](args... ). .. , f. args[3 ](args... ). .. )
729
+ @generated function (f:: FusionApply{N} )(args... ) where {N}
730
+ fargs = Any[ :(getfield (f. args, $ i)(args... )) for i in 1 : N ]
731
+ return Expr (:call , GlobalRef (Core, :_apply ), :(f. f), fargs... )
732
+ end
733
+ @inline function (f:: FusionKWCall )(args... )
734
+ fargs = map (a -> a (args... ), f. args)
735
+ # return f.f(args...; kwargs...)
736
+ if isempty (f. kwargs)
737
+ return f. f (fargs... )
738
+ else
739
+ return Core. kwfunc (f. f)(f. kwargs, f. f, fargs... )
740
+ end
741
+ end
742
+ @inline function (f:: FusionKWApply )(args... )
743
+ fargs = map (a -> a (args... ), f. args)
744
+ # return Core._apply(f.f, args...; kwargs...)
745
+ if isempty (f. kwargs)
746
+ return Core. _apply (f. f, fargs... )
747
+ else
748
+ return Core. _apply (Core. kwfunc (f. f), (f. kwargs,), (f. f,), fargs... )
749
+ end
750
+ end
751
+
752
+ function Base. show (io:: IO , f:: Fusion{N, vararg} ) where {N, vararg}
753
+ nargs = (vararg ? N + 1 : N)
754
+ names = String[ " a_$i " for i in 1 : nargs ] # f.names
755
+ print (io, " (" )
756
+ join (io, names, " , " )
757
+ vararg && print (io, " ..." )
758
+ print (io, " ) -> " )
759
+ show_fusion (io, f. f, names)
760
+ end
761
+
762
+ function show_fusion (io:: IO , f:: FusionArg{N} , names) where N
763
+ print (io, names[N])
764
+ nothing
765
+ end
766
+
767
+ function show_fusion (io:: IO , f:: FusionConstant{N} , names) where N
768
+ print (io, f. c)
769
+ nothing
770
+ end
771
+
772
+ function show_fusion (io:: IO , f:: FusionCall , names)
773
+ Base. show (io, f. f)
774
+ print (io, ' (' )
775
+ first = true
776
+ for i in f. args
777
+ first || print (io, " , " )
778
+ first = false
779
+ show_fusion (io, i, names)
780
+ end
781
+ print (io, ' )' )
782
+ nothing
783
+ end
784
+
785
+ function show_fusion (io:: IO , f:: FusionApply , names)
786
+ print (io, " Core._apply(" )
787
+ Base. show (io, f. f)
788
+ for i in f. args
789
+ print (io, " , " )
790
+ show_fusion (io, i, names)
791
+ end
792
+ print (io, ' )' )
793
+ nothing
794
+ end
795
+
796
+ function show_fusion (io:: IO , f:: FusionKWCall , names)
797
+ Base. show (io, f. f)
798
+ print (io, ' (' )
799
+ first = true
800
+ for i in f. args
801
+ first || print (io, " , " )
802
+ first = false
803
+ show_fusion (io, i, names)
804
+ end
805
+ print (io, " ; " )
806
+ first = true
807
+ for i in 1 : 2 : length (f. kwargs)
808
+ first || print (io, " , " )
809
+ first = false
810
+ print (io, f. kwargs[i])
811
+ print (io, " =" )
812
+ end
813
+ print (io, ' )' )
814
+ nothing
815
+ end
816
+
817
+
818
+ function show_fusion (io:: IO , f:: FusionKWApply , names)
819
+ print (io, " Core._apply(" )
820
+ Base. show (io, f. f)
821
+ for i in f. args
822
+ print (io, " , " )
823
+ show_fusion (io, i, names)
824
+ end
825
+ print (io, " ; #=kwargs=#...)" )
826
+ nothing
827
+ end
828
+
829
+
830
+ function show_fusion (io:: IO , @nospecialize (f), names)
831
+ print (io, " #= unexpected expression " )
832
+ show (io, f)
833
+ print (io, " =#" )
834
+ nothing
835
+ end
836
+
612
837
end # module
0 commit comments