From 4e23a04170540dd69d549df941dde2c32124a9e6 Mon Sep 17 00:00:00 2001
From: William Moses
Date: Fri, 7 Mar 2025 18:40:17 -0600
Subject: [PATCH] Move executable and device to thunk from expr (#855)

* Move executable and device to thunk from expr

* fix err

* fix err

* fixup

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 src/Compiler.jl | 43 ++++++++++++++++++++++++++++---------------
 1 file changed, 28 insertions(+), 15 deletions(-)

diff --git a/src/Compiler.jl b/src/Compiler.jl
index 9f6eb40a7..230c3a2b8 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -1391,14 +1391,11 @@ Generate Julia code to call the XLA executable.
 
 # Arguments
 
-- `exec`: The XLA executable to call.
 - `flatten_names`: A list of `Symbol`s representing the names of the flattened linear arguments.
 - `donated_args_mask`: A list of `UInt8`s representing whether the argument is donated.
 - `nresults`: The number of results to expect.
 """
 function codegen_xla_call(
-    exec,
-    device,
     flatten_names,
     donated_args_mask,
     nresults,
@@ -1420,7 +1417,7 @@ function codegen_xla_call(
         quote
             GC.@preserve $(flatten_names...) begin
                 linearized_results = XLA.execute(
-                    $exec,
+                    thunk.exec,
                     ($(flatten_buffer_refs...),),
                     $(Tuple(donated_args_mask)),
                     Val($nresults),
@@ -1433,8 +1430,8 @@ function codegen_xla_call(
         quote
             GC.@preserve $(flatten_names...) begin
                 linearized_results = XLA.execute_sharded(
-                    $exec,
-                    $(device),
+                    thunk.exec,
+                    thunk.device,
                     ($(flatten_buffer_refs...),),
                     $(Tuple(donated_args_mask)),
                     Val($nresults),
@@ -1600,8 +1597,6 @@ function compile(f, args; sync=false, kwargs...)
     )
 
     concretized_res_names, xla_call_code = codegen_xla_call(
-        exec,
-        device,
         flatten_arg_names,
         donated_args_mask,
         length(linear_results),
@@ -1653,15 +1648,23 @@ function compile(f, args; sync=false, kwargs...)
     end
 
     return register_thunk(
-        fname, Tuple{map(Core.Typeof, args)...}, body, f, mlir_fn_res.fnwrapped
+        fname,
+        Tuple{map(Core.Typeof, args)...},
+        body,
+        f,
+        mlir_fn_res.fnwrapped,
+        exec,
+        mlir_fn_res.is_sharded ? nothing : device,
     )
 end
 
 # inspired by RuntimeGeneratedFunction.jl
 const __thunk_body_cache = Dict{Symbol,Expr}()
 
-struct Thunk{FTy,tag,IsClosure,ArgTypes}
+struct Thunk{FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy}
     f::FTy
+    exec::ExecTy
+    device::DeviceTy
 end
 
 struct MisMatchedThunkTypeError{ThunkTy,FoundTypes} <: Base.Exception end
@@ -1687,14 +1690,16 @@ function Base.showerror(
     )
 end
 
-@generated function (thunk::Thunk{FTy,tag,ArgTypes,IsClosure})(
+@generated function (thunk::Thunk{FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy})(
     args...
-) where {FTy,tag,ArgTypes,IsClosure}
+) where {FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy}
     FoundTypes = Tuple{args...}
     if ArgTypes != FoundTypes
         return quote
             throw(
-                $(MisMatchedThunkTypeError{Thunk{FTy,tag,ArgTypes,IsClosure},FoundTypes}())
+                $(MisMatchedThunkTypeError{
+                    Thunk{FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy},FoundTypes
+                }()),
             )
         end
     end
@@ -1710,10 +1715,18 @@ end
 end
 
 function register_thunk(
-    tag::Symbol, @nospecialize(argtys::Type), body::Expr, @nospecialize(f), isclosure::Bool
+    tag::Symbol,
+    @nospecialize(argtys::Type),
+    body::Expr,
+    @nospecialize(f),
+    isclosure::Bool,
+    exec,
+    device,
 )
     __thunk_body_cache[tag] = body
-    return Thunk{Core.Typeof(f),tag,argtys,isclosure}(f)
+    return Thunk{Core.Typeof(f),tag,argtys,isclosure,Core.Typeof(exec),Core.Typeof(device)}(
+        f, exec, device
+    )
 end
 
 for cache_type in (:callcache, :sdycache)
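
Editor's note: as a minimal self-contained sketch of the pattern this patch adopts, the executable and device become fields on the thunk object, and the generated call body refers to thunk.exec / thunk.device instead of having those objects spliced into the cached expression. The ToyThunk / register_toy_thunk names below are hypothetical stand-ins for Reactant's Thunk / register_thunk; no Reactant or XLA APIs are used.

# Hypothetical sketch (not Reactant's actual API): cache a body Expr per tag,
# and store the "executable" and "device" as fields of the callable struct so
# the cached expression can reference them via `thunk.exec` / `thunk.device`.
const __toy_body_cache = Dict{Symbol,Expr}()

struct ToyThunk{tag,ExecTy,DeviceTy}
    exec::ExecTy
    device::DeviceTy
end

function register_toy_thunk(tag::Symbol, body::Expr, exec, device)
    __toy_body_cache[tag] = body
    return ToyThunk{tag,Core.Typeof(exec),Core.Typeof(device)}(exec, device)
end

@generated function (thunk::ToyThunk{tag})(args...) where {tag}
    # Splice in the cached body; at runtime `thunk` and `args` are the call's
    # actual arguments, so `thunk.exec` / `thunk.device` resolve to the fields.
    return __toy_body_cache[tag]
end

# Usage: here the "executable" is a plain closure and the "device" a string.
t = register_toy_thunk(
    :demo, :(string(thunk.device, " => ", thunk.exec(args...))), (x, y) -> x + y, "cpu:0"
)
t(2, 3)  # "cpu:0 => 5"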