diff --git a/Project.toml b/Project.toml index fb78ede..268c8de 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.2.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232" +Cassette = "7057c7e9-c182-5462-911a-8362d720325c" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" HSARuntime = "2c364e2c-59fb-59c3-96f3-194112e690e0" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" diff --git a/src/AMDGPUnative.jl b/src/AMDGPUnative.jl index a209342..5bda187 100644 --- a/src/AMDGPUnative.jl +++ b/src/AMDGPUnative.jl @@ -28,6 +28,7 @@ struct Adaptor end # Device sources must load _before_ the compiler infrastructure # because of generated functions. +isdevice() = false include(joinpath("device", "tools.jl")) include(joinpath("device", "pointer.jl")) include(joinpath("device", "array.jl")) @@ -36,6 +37,7 @@ include(joinpath("device", "runtime.jl")) include("execution_utils.jl") include("compiler.jl") +include("context.jl") include("execution.jl") include("reflection.jl") diff --git a/src/compiler/common.jl b/src/compiler/common.jl index 86c5067..e8ebac1 100644 --- a/src/compiler/common.jl +++ b/src/compiler/common.jl @@ -7,6 +7,8 @@ struct CompilerJob device::RuntimeDevice kernel::Bool + contextualize::Bool + # optional properties minthreads::Union{Nothing,ROCDim} maxthreads::Union{Nothing,ROCDim} @@ -16,8 +18,10 @@ struct CompilerJob CompilerJob(f, tt, device, kernel; name=nothing, minthreads=nothing, maxthreads=nothing, - blocks_per_sm=nothing, maxregs=nothing) = - new(f, tt, device, kernel, minthreads, maxthreads, blocks_per_sm, + blocks_per_sm=nothing, maxregs=nothing, + contextualize=true) = + new(f, tt, device, kernel, contextualize, + minthreads, maxthreads, blocks_per_sm, maxregs, name) end diff --git a/src/compiler/driver.jl b/src/compiler/driver.jl index ceeb10d..3102663 100644 --- a/src/compiler/driver.jl +++ b/src/compiler/driver.jl @@ -51,11 +51,12 @@ function 
codegen(target::Symbol, job::CompilerJob; libraries::Bool=true, @timeit to[] "validation" check_method(job) @timeit to[] "Julia front-end" begin + f = job.contextualize ? contextualize(job.f) : job.f # get the method instance world = typemax(UInt) - meth = which(job.f, job.tt) - sig = Base.signature_type(job.f, job.tt)::Type + meth = which(f, job.tt) + sig = Base.signature_type(f, job.tt)::Type (ti, env) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), sig, meth.sig)::Core.SimpleVector if VERSION >= v"1.2.0-DEV.320" diff --git a/src/context.jl b/src/context.jl new file mode 100644 index 0000000..34a5c33 --- /dev/null +++ b/src/context.jl @@ -0,0 +1,75 @@ +## +# Implements contextual dispatch through Cassette.jl +# Goals: +# - Rewrite common CPU functions to appropriate GPU intrinsics +# +# TODO: +# - error (erf, ...) +# - pow +# - min, max +# - mod, rem +# - gamma +# - bessel +# - distributions +# - unsorted + +using Cassette + +function transform(ctx, ref) + CI = ref.code_info + noinline = any(@nospecialize(x) -> + Core.Compiler.isexpr(x, :meta) && + x.args[1] == :noinline, + CI.code) + CI.inlineable = !noinline + + CI.ssavaluetypes = length(CI.code) + # Core.Compiler.validate_code(CI) + return CI +end + +const InlinePass = Cassette.@pass transform + +Cassette.@context ROCCtx +const rocctx = Cassette.disablehooks(ROCCtx(pass = InlinePass)) + +### +# Cassette fixes +### + +# kwfunc fix +Cassette.overdub(::ROCCtx, ::typeof(Core.kwfunc), f) = return Core.kwfunc(f) + +# the functions below are marked `@pure` and by rewriting them we hide that from +# inference so we leave them alone (see https://github.com/jrevels/Cassette.jl/issues/108). 
+@inline Cassette.overdub(::ROCCtx, ::typeof(Base.isimmutable), x) = return Base.isimmutable(x) +@inline Cassette.overdub(::ROCCtx, ::typeof(Base.isstructtype), t) = return Base.isstructtype(t) +@inline Cassette.overdub(::ROCCtx, ::typeof(Base.isprimitivetype), t) = return Base.isprimitivetype(t) +@inline Cassette.overdub(::ROCCtx, ::typeof(Base.isbitstype), t) = return Base.isbitstype(t) +@inline Cassette.overdub(::ROCCtx, ::typeof(Base.isbits), x) = return Base.isbits(x) + +@inline Cassette.overdub(::ROCCtx, ::typeof(datatype_align), ::Type{T}) where {T} = datatype_align(T) + +### +# Rewrite functions +### +Cassette.overdub(ctx::ROCCtx, ::typeof(isdevice)) = true + +# libdevice.jl +for f in (:cos, :cospi, :sin, :sinpi, :tan, + :acos, :asin, :atan, + :cosh, :sinh, :tanh, + :acosh, :asinh, :atanh, + :log, :log10, :log1p, :log2, + :exp, :exp2, :exp10, :expm1, :ldexp, + :isfinite, :isinf, :isnan, + :signbit, :abs, + :sqrt, :cbrt, + :ceil, :floor,) + @eval function Cassette.overdub(ctx::ROCCtx, ::typeof(Base.$f), x::Union{Float32, Float64}) + @Base._inline_meta + return AMDGPUnative.$f(x) + end +end + +contextualize(f::F) where F = (args...) -> Cassette.overdub(rocctx, f, args...) diff --git a/src/execution.jl b/src/execution.jl index 090cf39..3b95eca 100644 --- a/src/execution.jl +++ b/src/execution.jl @@ -12,7 +12,7 @@ end # affecting the compiler, kernel execution, or both. function split_kwargs(kwargs) # TODO: Alias groupsize and gridsize as threads and blocks, respectively - compiler_kws = [:device, :agent, :queue, :name] + compiler_kws = [:device, :agent, :queue, :name, :contextualize] call_kws = [:groupsize, :gridsize, :device, :agent, :queue] compiler_kwargs = [] call_kwargs = [] @@ -199,7 +199,8 @@ Low-level interface to compile a function invocation for the currently-active GPU, returning a callable kernel object. For a higher-level interface, use [`@roc`](@ref). -Currently, no keyword arguments are implemented. 
+The following keyword arguments are supported: +- `contextualize`: whether to contextualize functions using Cassette (default: true) The output of this function is automatically cached, i.e. you can simply call `rocfunction` in a hot path without degrading performance. New code will be