Skip to content
This repository was archived by the owner on Nov 18, 2020. It is now read-only.

Commit 6cb1352

Browse files
committed
Contextualization of kernels with Cassette
1 parent 2014ec6 commit 6cb1352

File tree

5 files changed

+88
-6
lines changed

5 files changed

+88
-6
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@ version = "0.2.1"
55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
77
BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
8+
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
89
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
910
HSARuntime = "2c364e2c-59fb-59c3-96f3-194112e690e0"
1011
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

src/compiler/common.jl

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,8 @@ struct CompilerJob
77
device::RuntimeDevice
88
kernel::Bool
99

10+
contextualize::Bool
11+
1012
# optional properties
1113
minthreads::Union{Nothing,ROCDim}
1214
maxthreads::Union{Nothing,ROCDim}
@@ -16,8 +18,10 @@ struct CompilerJob
1618

1719
CompilerJob(f, tt, device, kernel; name=nothing,
1820
minthreads=nothing, maxthreads=nothing,
19-
blocks_per_sm=nothing, maxregs=nothing) =
20-
new(f, tt, device, kernel, minthreads, maxthreads, blocks_per_sm,
21+
blocks_per_sm=nothing, maxregs=nothing,
22+
contextualize=true) =
23+
new(f, tt, device, kernel, contextualize,
24+
minthreads, maxthreads, blocks_per_sm,
2125
maxregs, name)
2226
end
2327

src/compiler/driver.jl

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -51,11 +51,12 @@ function codegen(target::Symbol, job::CompilerJob; libraries::Bool=true,
5151
@timeit to[] "validation" check_method(job)
5252

5353
@timeit to[] "Julia front-end" begin
54+
f = job.contextualize ? contextualize(job.f) : job.f
5455

5556
# get the method instance
5657
world = typemax(UInt)
57-
meth = which(job.f, job.tt)
58-
sig = Base.signature_type(job.f, job.tt)::Type
58+
meth = which(f, job.tt)
59+
sig = Base.signature_type(f, job.tt)::Type
5960
(ti, env) = ccall(:jl_type_intersection_with_env, Any,
6061
(Any, Any), sig, meth.sig)::Core.SimpleVector
6162
if VERSION >= v"1.2.0-DEV.320"

src/context.jl

Lines changed: 75 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,75 @@
1+
##
2+
# Implements contextual dispatch through Cassette.jl
3+
# Goals:
4+
# - Rewrite common CPU functions to appropriate GPU intrinsics
5+
#
6+
# TODO:
7+
# - error (erf, ...)
8+
# - pow
9+
# - min, max
10+
# - mod, rem
11+
# - gamma
12+
# - bessel
13+
# - distributions
14+
# - unsorted
15+
16+
using Cassette
17+
18+
function transform(ctx, ref)
19+
CI = ref.code_info
20+
noinline = any(@nospecialize(x) ->
21+
Core.Compiler.isexpr(x, :meta) &&
22+
x.args[1] == :noinline,
23+
CI.code)
24+
CI.inlineable = !noinline
25+
26+
CI.ssavaluetypes = length(CI.code)
27+
# Core.Compiler.validate_code(CI)
28+
return CI
29+
end
30+
31+
const InlinePass = Cassette.@pass transform
32+
33+
Cassette.@context ROCCtx
34+
const rocctx = Cassette.disablehooks(ROCCtx(pass = InlinePass))
35+
36+
###
37+
# Cassette fixes
38+
###
39+
40+
# kwfunc fix
41+
Cassette.overdub(::ROCCtx, ::typeof(Core.kwfunc), f) = return Core.kwfunc(f)
42+
43+
# the functions below are marked `@pure` and by rewriting them we hide that from
44+
# inference so we leave them alone (see https://github.com/jrevels/Cassette.jl/issues/108).
45+
@inline Cassette.overdub(::ROCCtx, ::typeof(Base.isimmutable), x) = return Base.isimmutable(x)
46+
@inline Cassette.overdub(::ROCCtx, ::typeof(Base.isstructtype), t) = return Base.isstructtype(t)
47+
@inline Cassette.overdub(::ROCCtx, ::typeof(Base.isprimitivetype), t) = return Base.isprimitivetype(t)
48+
@inline Cassette.overdub(::ROCCtx, ::typeof(Base.isbitstype), t) = return Base.isbitstype(t)
49+
@inline Cassette.overdub(::ROCCtx, ::typeof(Base.isbits), x) = return Base.isbits(x)
50+
51+
@inline Cassette.overdub(::ROCCtx, ::typeof(datatype_align), ::Type{T}) where {T} = datatype_align(T)
52+
53+
###
54+
# Rewrite functions
55+
###
56+
Cassette.overdub(ctx::ROCCtx, ::typeof(isdevice)) = true
57+
58+
# libdevice.jl
59+
for f in (:cos, :cospi, :sin, :sinpi, :tan,
60+
:acos, :asin, :atan,
61+
:cosh, :sinh, :tanh,
62+
:acosh, :asinh, :atanh,
63+
:log, :log10, :log1p, :log2,
64+
:exp, :exp2, :exp10, :expm1, :ldexp,
65+
:isfinite, :isinf, :isnan,
66+
:signbit, :abs,
67+
:sqrt, :cbrt,
68+
:ceil, :floor,)
69+
@eval function Cassette.overdub(ctx::ROCCtx, ::typeof(Base.$f), x::Union{Float32, Float64})
70+
@Base._inline_meta
71+
return AMDGPUnative.$f(x)
72+
end
73+
end
74+
75+
contextualize(f::F) where F = (args...) -> Cassette.overdub(rocctx, f, args...)

src/execution.jl

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@ end
1212
# affecting the compiler, kernel execution, or both.
1313
function split_kwargs(kwargs)
1414
# TODO: Alias groupsize and gridsize as threads and blocks, respectively
15-
compiler_kws = [:device, :agent, :queue, :name]
15+
compiler_kws = [:device, :agent, :queue, :name, :contextualize]
1616
call_kws = [:groupsize, :gridsize, :device, :agent, :queue]
1717
compiler_kwargs = []
1818
call_kwargs = []
@@ -199,7 +199,8 @@ Low-level interface to compile a function invocation for the currently-active
199199
GPU, returning a callable kernel object. For a higher-level interface, use
200200
[`@roc`](@ref).
201201
202-
Currently, no keyword arguments are implemented.
202+
The following keyword arguments are supported:
203+
- `contextualize`: whether to contextualize functions using Cassette (default: true)
203204
204205
The output of this function is automatically cached, i.e. you can simply call
205206
`rocfunction` in a hot path without degrading performance. New code will be

0 commit comments

Comments
 (0)