Skip to content

Commit

Permalink
Move executable and device to thunk from expr (#855)
Browse files Browse the repository at this point in the history
* Move executable and device to thunk from expr

* fix err

* fix err

* fixup

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
wsmoses and github-actions[bot] authored Mar 8, 2025
1 parent 9177412 commit 4e23a04
Showing 1 changed file with 28 additions and 15 deletions.
43 changes: 28 additions & 15 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1391,14 +1391,11 @@ Generate Julia code to call the XLA executable.
# Arguments
- `exec`: The XLA executable to call.
- `flatten_names`: A list of `Symbol`s representing the names of the flattened linear arguments.
- `donated_args_mask`: A list of `UInt8`s representing whether the argument is donated.
- `nresults`: The number of results to expect.
"""
function codegen_xla_call(
exec,
device,
flatten_names,
donated_args_mask,
nresults,
Expand All @@ -1420,7 +1417,7 @@ function codegen_xla_call(
quote
GC.@preserve $(flatten_names...) begin
linearized_results = XLA.execute(
$exec,
thunk.exec,
($(flatten_buffer_refs...),),
$(Tuple(donated_args_mask)),
Val($nresults),
Expand All @@ -1433,8 +1430,8 @@ function codegen_xla_call(
quote
GC.@preserve $(flatten_names...) begin
linearized_results = XLA.execute_sharded(
$exec,
$(device),
thunk.exec,
thunk.device,
($(flatten_buffer_refs...),),
$(Tuple(donated_args_mask)),
Val($nresults),
Expand Down Expand Up @@ -1600,8 +1597,6 @@ function compile(f, args; sync=false, kwargs...)
)

concretized_res_names, xla_call_code = codegen_xla_call(
exec,
device,
flatten_arg_names,
donated_args_mask,
length(linear_results),
Expand Down Expand Up @@ -1653,15 +1648,23 @@ function compile(f, args; sync=false, kwargs...)
end

return register_thunk(
fname, Tuple{map(Core.Typeof, args)...}, body, f, mlir_fn_res.fnwrapped
fname,
Tuple{map(Core.Typeof, args)...},
body,
f,
mlir_fn_res.fnwrapped,
exec,
mlir_fn_res.is_sharded ? nothing : device,
)
end

# inspired by RuntimeGeneratedFunction.jl
# Maps a thunk's `tag` symbol to its body expression; `register_thunk` stores the entry.
const __thunk_body_cache = Dict{Symbol,Expr}()

"""
    Thunk{FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy}

Callable wrapper produced by `register_thunk` for a compiled function.

# Type parameters
- `FTy`: type of the wrapped function `f`.
- `tag`: `Symbol` under which the thunk body is stored in `__thunk_body_cache`.
- `ArgTypes`: `Tuple` type of the arguments the thunk was compiled for; the call
  operator compares it against the types of the arguments it actually receives.
- `IsClosure`: `Bool` flag forwarded from `register_thunk`'s `isclosure` argument.
- `ExecTy`, `DeviceTy`: types of the `exec` and `device` fields.

The parameter names are listed in the same positional order used by the call
operator and by `register_thunk` (`ArgTypes` third, `IsClosure` fourth).

# Fields
- `f`: the original function.
- `exec`: the executable the generated call code reads as `thunk.exec`.
- `device`: the device the generated call code reads as `thunk.device`; `compile`
  passes `nothing` here for sharded functions.
"""
struct Thunk{FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy}
    f::FTy
    exec::ExecTy
    device::DeviceTy
end

struct MisMatchedThunkTypeError{ThunkTy,FoundTypes} <: Base.Exception end
Expand All @@ -1687,14 +1690,16 @@ function Base.showerror(
)
end

@generated function (thunk::Thunk{FTy,tag,ArgTypes,IsClosure})(
@generated function (thunk::Thunk{FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy})(
args...
) where {FTy,tag,ArgTypes,IsClosure}
) where {FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy}
FoundTypes = Tuple{args...}
if ArgTypes != FoundTypes
return quote
throw(
$(MisMatchedThunkTypeError{Thunk{FTy,tag,ArgTypes,IsClosure},FoundTypes}())
$(MisMatchedThunkTypeError{
Thunk{FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy},FoundTypes
}()),
)
end
end
Expand All @@ -1710,10 +1715,18 @@ end
end

"""
    register_thunk(tag, argtys, body, f, isclosure, exec, device)

Store `body` in `__thunk_body_cache` under `tag` and return a `Thunk` wrapping `f`,
`exec` and `device`.

# Arguments
- `tag::Symbol`: key for the cached body; also becomes a type parameter of the thunk.
- `argtys::Type`: argument types the thunk was compiled for (type parameter).
- `body::Expr`: expression stored in the cache for this `tag`.
- `f`: the original function; its type becomes the thunk's first type parameter.
- `isclosure::Bool`: flag recorded as a type parameter of the thunk.
- `exec`: executable stored in the thunk's `exec` field.
- `device`: device stored in the thunk's `device` field (may be `nothing`).

# Returns
A `Thunk{typeof(f),tag,argtys,isclosure,typeof(exec),typeof(device)}`.
"""
function register_thunk(
    tag::Symbol,
    @nospecialize(argtys::Type),
    body::Expr,
    @nospecialize(f),
    isclosure::Bool,
    exec,
    device,
)
    # An existing entry for the same tag is overwritten.
    __thunk_body_cache[tag] = body
    return Thunk{Core.Typeof(f),tag,argtys,isclosure,Core.Typeof(exec),Core.Typeof(device)}(
        f, exec, device
    )
end

for cache_type in (:callcache, :sdycache)
Expand Down

0 comments on commit 4e23a04

Please sign in to comment.