diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index a0275dea3..ecd144456 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -225,6 +225,20 @@ struct LLVMFunc{F,tt} entry::String end +function Base.getproperty(f::LLVMFunc{F, tt}, sym::Symbol) where {F, tt} + if sym === :fun + f + else + Base.getfield(f, sym) + end +end + +# TODO in the future we may want to avoid doing a second cufunction compilation +# for computing the thread/block count (or potentially do it ourselves). +@noinline function CUDA.launch_configuration(f::LLVMFunc{F, tt}; shmem::Union{Integer, Base.Callable}=0, max_threads::Integer=0) where {F, tt} + CUDA.launch_configuration(Base.inferencebarrier(CUDA.cufunction)(f.f, Tuple{tt.parameters[2:end]...}).fun; shmem, max_threads) +end + const GPUCompiler = CUDA.GPUCompiler const LLVM = GPUCompiler.LLVM @@ -456,7 +470,7 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction( ) CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link) end - return res + return Core.Typeof(res)(f, res.entry) end function Reactant.traced_type( diff --git a/src/utils.jl b/src/utils.jl index e2f518cdf..ebc937cac 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -105,6 +105,11 @@ function should_rewrite_ft(@nospecialize(ft)) has_ancestor(mod, Reactant.TracedRandom) return false end + if string(mod) == "CUDA" + if ft.name.name == Symbol("#launch_configuration") + return false + end + end end end # Don't rewrite Val @@ -153,6 +158,8 @@ function should_rewrite_ft(@nospecialize(ft)) return false end + + # Default assume all functions need to be reactant-ified return true end @@ -217,7 +224,7 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error) end if ft == typeof(Core._apply_iterate) ft = Core.Compiler.widenconst(maybe_argextype(inst.args[3], ir)) - if should_rewrite_ft(ft) + if Base.invokelatest(should_rewrite_ft, ft) if RT === Union{} rep = Expr( :call, @@ -231,7 +238,7 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error) return true, rep, Any end end - elseif should_rewrite_ft(ft) + elseif Base.invokelatest(should_rewrite_ft, ft) if RT === Union{} rep = Expr(:call, call_with_reactant, MustThrowError(), inst.args...) return true, rep, Union{} @@ -248,7 +255,7 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error) if ft == typeof(Core.kwcall) ft = sig.parameters[3] end - if should_rewrite_ft(ft) && !is_reactant_method(omi) + if Base.invokelatest(should_rewrite_ft, ft) && !is_reactant_method(omi) method = omi.def::Core.Method min_world = Ref{UInt}(typemin(UInt)) @@ -479,9 +486,15 @@ function call_with_reactant_generator( return stub(world, source, builtin_error) end - method_error = :(throw( - MethodError($REDUB_ARGUMENTS_NAME[1], $REDUB_ARGUMENTS_NAME[2:end], $world) - )) + if guaranteed_error + method_error = :(throw( + MethodError($REDUB_ARGUMENTS_NAME[2], $REDUB_ARGUMENTS_NAME[3:end], $world) + )) + else + method_error = :(throw( + MethodError($REDUB_ARGUMENTS_NAME[1], $REDUB_ARGUMENTS_NAME[2:end], $world) + )) + end interp = ReactantInterpreter(; world) @@ -675,7 +688,7 @@ function call_with_reactant_generator( dict, make_oc = if Base.issingletontype(fn) Base.Ref{Core.OpaqueClosure}(), make_oc_ref else - Dict{args[1],Core.OpaqueClosure}(), make_oc_dict + Dict{fn,Core.OpaqueClosure}(), make_oc_dict end push!(oc_capture_vec, dict)