Skip to content
This repository was archived by the owner on Nov 18, 2020. It is now read-only.

[WIP] Contextualization of kernels with Cassette #44

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.2.1"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
HSARuntime = "2c364e2c-59fb-59c3-96f3-194112e690e0"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand Down
2 changes: 2 additions & 0 deletions src/AMDGPUnative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ struct Adaptor end

# Device sources must load _before_ the compiler infrastructure
# because of generated functions.
isdevice() = false
include(joinpath("device", "tools.jl"))
include(joinpath("device", "pointer.jl"))
include(joinpath("device", "array.jl"))
Expand All @@ -36,6 +37,7 @@ include(joinpath("device", "runtime.jl"))

include("execution_utils.jl")
include("compiler.jl")
include("context.jl")
include("execution.jl")
include("reflection.jl")

Expand Down
8 changes: 6 additions & 2 deletions src/compiler/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ struct CompilerJob
device::RuntimeDevice
kernel::Bool

contextualize::Bool

# optional properties
minthreads::Union{Nothing,ROCDim}
maxthreads::Union{Nothing,ROCDim}
Expand All @@ -16,8 +18,10 @@ struct CompilerJob

CompilerJob(f, tt, device, kernel; name=nothing,
minthreads=nothing, maxthreads=nothing,
blocks_per_sm=nothing, maxregs=nothing) =
new(f, tt, device, kernel, minthreads, maxthreads, blocks_per_sm,
blocks_per_sm=nothing, maxregs=nothing,
contextualize=true) =
new(f, tt, device, kernel, contextualize,
minthreads, maxthreads, blocks_per_sm,
maxregs, name)
end

Expand Down
5 changes: 3 additions & 2 deletions src/compiler/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@ function codegen(target::Symbol, job::CompilerJob; libraries::Bool=true,
@timeit to[] "validation" check_method(job)

@timeit to[] "Julia front-end" begin
f = job.contextualize ? contextualize(job.f) : job.f

# get the method instance
world = typemax(UInt)
meth = which(job.f, job.tt)
sig = Base.signature_type(job.f, job.tt)::Type
meth = which(f, job.tt)
sig = Base.signature_type(f, job.tt)::Type
(ti, env) = ccall(:jl_type_intersection_with_env, Any,
(Any, Any), sig, meth.sig)::Core.SimpleVector
if VERSION >= v"1.2.0-DEV.320"
Expand Down
75 changes: 75 additions & 0 deletions src/context.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
##
# Implements contextual dispatch through Cassette.jl
# Goals:
# - Rewrite common CPU functions to appropriate GPU intrinsics
#
# TODO:
# - error (erf, ...)
# - pow
# - min, max
# - mod, rem
# - gamma
# - bessel
# - distributions
# - unsorted

using Cassette

function transform(ctx, ref)
CI = ref.code_info
noinline = any(@nospecialize(x) ->
Core.Compiler.isexpr(x, :meta) &&
x.args[1] == :noinline,
CI.code)
CI.inlineable = !noinline

CI.ssavaluetypes = length(CI.code)
# Core.Compiler.validate_code(CI)
return CI
end

const InlinePass = Cassette.@pass transform

Cassette.@context ROCCtx
const rocctx = Cassette.disablehooks(ROCCtx(pass = InlinePass))

###
# Cassette fixes
###

# kwfunc fix
Cassette.overdub(::ROCCtx, ::typeof(Core.kwfunc), f) = return Core.kwfunc(f)

# the functions below are marked `@pure` and by rewritting them we hide that from
# inference so we leave them alone (see https://github.com/jrevels/Cassette.jl/issues/108).
@inline Cassette.overdub(::ROCCtx, ::typeof(Base.isimmutable), x) = return Base.isimmutable(x)
@inline Cassette.overdub(::ROCCtx, ::typeof(Base.isstructtype), t) = return Base.isstructtype(t)
@inline Cassette.overdub(::ROCCtx, ::typeof(Base.isprimitivetype), t) = return Base.isprimitivetype(t)
@inline Cassette.overdub(::ROCCtx, ::typeof(Base.isbitstype), t) = return Base.isbitstype(t)
@inline Cassette.overdub(::ROCCtx, ::typeof(Base.isbits), x) = return Base.isbits(x)

@inline Cassette.overdub(::ROCCtx, ::typeof(datatype_align), ::Type{T}) where {T} = datatype_align(T)

###
# Rewrite functions
###
Cassette.overdub(ctx::ROCCtx, ::typeof(isdevice)) = true

# libdevice.jl
for f in (:cos, :cospi, :sin, :sinpi, :tan,
:acos, :asin, :atan,
:cosh, :sinh, :tanh,
:acosh, :asinh, :atanh,
:log, :log10, :log1p, :log2,
:exp, :exp2, :exp10, :expm1, :ldexp,
:isfinite, :isinf, :isnan,
:signbit, :abs,
:sqrt, :cbrt,
:ceil, :floor,)
@eval function Cassette.overdub(ctx::ROCCtx, ::typeof(Base.$f), x::Union{Float32, Float64})
@Base._inline_meta
return AMDGPUnative.$f(x)
end
end

contextualize(f::F) where F = (args...) -> Cassette.overdub(rocctx, f, args...)
5 changes: 3 additions & 2 deletions src/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ end
# affecting the compiler, kernel execution, or both.
function split_kwargs(kwargs)
# TODO: Alias groupsize and gridsize as threads and blocks, respectively
compiler_kws = [:device, :agent, :queue, :name]
compiler_kws = [:device, :agent, :queue, :name, :contextualize]
call_kws = [:groupsize, :gridsize, :device, :agent, :queue]
compiler_kwargs = []
call_kwargs = []
Expand Down Expand Up @@ -199,7 +199,8 @@ Low-level interface to compile a function invocation for the currently-active
GPU, returning a callable kernel object. For a higher-level interface, use
[`@roc`](@ref).

Currently, no keyword arguments are implemented.
The following keyword arguments are supported:
- `contextualize`: whether to contextualize functions using Cassette (default: true)

The output of this function is automatically cached, i.e. you can simply call
`rocfunction` in a hot path without degrading performance. New code will be
Expand Down