Skip to content

Commit ffbedbb

Browse files
authored
1.11 enablement (#497)
1 parent 77247f7 commit ffbedbb

File tree

9 files changed

+425
-456
lines changed

9 files changed

+425
-456
lines changed

.github/workflows/ci.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,14 @@ jobs:
4646
strategy:
4747
fail-fast: false
4848
matrix:
49-
branch: ['release-1.6', 'release-1.7', 'release-1.8', 'release-1.9', 'master']
49+
branch: ['release-1.8', 'release-1.9', 'release-1.10', 'master']
5050
os: ['ubuntu-latest', 'macOS-latest', 'windows-latest']
5151
arch: [x64]
5252
exclude:
53+
# unknown segfault in LLVM
54+
- branch: 'release-1.10'
55+
os: 'windows-latest'
56+
arch: 'x64'
5357
# JuliaLang/julia#48081
5458
- branch: 'master'
5559
os: 'windows-latest'
@@ -58,10 +62,6 @@ jobs:
5862
- branch: 'master'
5963
os: 'macOS-latest'
6064
arch: 'x64'
61-
# 1.6 requires gfortran, which isn't available on macOS runners
62-
- branch: 'release-1.6'
63-
os: 'macOS-latest'
64-
arch: 'x64'
6565
steps:
6666
- uses: actions/checkout@v3
6767
- uses: actions/checkout@v3

examples/jit.jl

Lines changed: 320 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,320 @@
1+
# example of how to integrate GPUCompiler.jl with an LLVM Orc-based JIT
2+
3+
using GPUCompiler
4+
5+
# Stub runtime module: it is returned by the `GPUCompiler.runtime_module` method defined
# for `TestCompilerParams` jobs. Every function here is a placeholder that does nothing.
module TestRuntime

# dummy methods
signal_exception() = nothing
malloc(sz) = C_NULL    # never allocates; always reports a null pointer
report_oom(sz) = nothing
report_exception(ex) = nothing
report_exception_name(ex) = nothing
report_exception_frame(idx, func, file, line) = nothing

end
14+
15+
# Marker parameter type: it carries no data and exists so that methods (such as
# `runtime_module` below) can dispatch on jobs created by this example.
struct TestCompilerParams <: AbstractCompilerParams end

# Jobs parameterized by `TestCompilerParams` use the stub `TestRuntime` module.
function GPUCompiler.runtime_module(::CompilerJob{<:Any,TestCompilerParams})
    return TestRuntime
end
17+
18+
19+
## JIT integration
20+
21+
using LLVM, LLVM.Interop
22+
23+
# XXX: this example has bitrotten, due to many changes to Julia's JIT.
24+
# Bail out early, with a zero exit status, in the configurations where this example is
# known to fail: an assertions-enabled LLVM, or Windows on Julia 1.10 and later.
if LLVM.is_asserts()
    @error "The JIT example fails LLVM assertions, and is therefor disabled."
    exit(0)
end
if Sys.iswindows() && VERSION >= v"1.10-"
    @error "The JIT example fails on Windows with Julia 1.10+, and is therefor disabled."
    exit(0)
end
32+
33+
"""
    absolute_symbol_materialization(name, ptr)

Pair `name` with the address held in `ptr` (flagged as exported) and return the result of
`LLVM.absolute_symbols` for that single pair. The returned object is what
`define_absolute_symbol` hands to `LLVM.define`.
"""
function absolute_symbol_materialization(name, ptr)
    target_addr = LLVM.API.LLVMOrcJITTargetAddress(reinterpret(UInt, ptr))
    export_flags = LLVM.API.LLVMJITSymbolFlags(LLVM.API.LLVMJITSymbolGenericFlagsExported, 0)
    evaluated = LLVM.API.LLVMJITEvaluatedSymbol(target_addr, export_flags)

    # the name of the pair type depends on the LLVM version in use; only the branch that
    # is taken looks up its type, so the other name does not need to exist
    pair = LLVM.version() >= v"15" ?
        LLVM.API.LLVMOrcCSymbolMapPair(name, evaluated) :
        LLVM.API.LLVMJITCSymbolMapPair(name, evaluated)

    return LLVM.absolute_symbols(Ref(pair))
end
45+
46+
"""
    define_absolute_symbol(jd, name) -> Bool

Look up `name` with `LLVM.find_symbol`; when the lookup yields a non-null pointer, define
the symbol in `jd` at that address and return `true`. Return `false` (defining nothing)
when the lookup yields `C_NULL`.
"""
function define_absolute_symbol(jd, name)
    ptr = LLVM.find_symbol(name)
    ptr === C_NULL && return false

    LLVM.define(jd, absolute_symbol_materialization(name, ptr))
    return true
end
54+
55+
"""
    CompilerInstance

Bundle of the three JIT objects that `get_trampoline` needs: the `LLJIT` itself, a lazy
call-through manager and an indirect stubs manager.
"""
struct CompilerInstance
    jit::LLVM.LLJIT
    lctm::LLVM.LazyCallThroughManager
    ism::LLVM.IndirectStubsManager
end

# Global slot for the active compiler instance. It starts out unassigned and is filled in
# by the setup code further down (`jit[] = CompilerInstance(...)`).
const jit = Ref{CompilerInstance}()
61+
62+
"""
    get_trampoline(job)

Register a lazily materialized entry point for `job` with the global JIT instance and
return the result of looking up that entry symbol.

Two fresh symbol names are generated: an *entry* symbol, which is what gets looked up and
returned, and a *target* symbol, whose definition is produced on demand by compiling `job`
with GPUCompiler.
"""
function get_trampoline(job)
    instance = jit[]
    lljit = instance.jit

    # We could also use one dylib per job
    jd = JITDylib(lljit)

    entry_sym = String(gensym(:entry))
    target_sym = String(gensym(:target))
    flags = LLVM.API.LLVMJITSymbolFlags(
        LLVM.API.LLVMJITSymbolGenericFlagsCallable |
        LLVM.API.LLVMJITSymbolGenericFlagsExported, 0)

    # 1. Define the entry symbol as an alias of the target symbol, using a re-exports
    #    unit built from the lazy call-through manager and the indirect stubs manager
    alias = LLVM.API.LLVMOrcCSymbolAliasMapPair(
        mangle(lljit, entry_sym),
        LLVM.API.LLVMOrcCSymbolAliasMapEntry(
            mangle(lljit, target_sym), flags))
    LLVM.define(jd, LLVM.reexports(instance.lctm, instance.ism, jd, Ref(alias)))

    # 2. Lookup address of entry symbol
    addr = lookup(lljit, entry_sym)

    # 3. add MU that will call back into the compiler
    target_flags = LLVM.API.LLVMOrcCSymbolFlagsMapPair(mangle(lljit, target_sym), flags)

    # callback handed to the custom materialization unit: compiles `job` and emits the
    # resulting module for the materialization responsibility `mr`
    function materialize(mr)
        JuliaContext() do ctx
            ir, meta = GPUCompiler.compile(:llvm, job; validate=false)

            # Rename entry to match target_sym
            LLVM.name!(meta.entry, target_sym)

            # So 1. serialize the module
            serialized = convert(MemoryBuffer, ir)

            # 2. deserialize and wrap by a ThreadSafeModule
            ThreadSafeContext() do ts_ctx
                tsm = context!(context(ts_ctx)) do
                    ThreadSafeModule(parse(LLVM.Module, serialized))
                end

                LLVM.emit(LLVM.IRTransformLayer(lljit), mr, tsm)
            end
        end

        return nothing
    end

    # nothing to clean up when a symbol is discarded
    discard(jd, sym) = nothing

    unit = LLVM.CustomMaterializationUnit(entry_sym, Ref(target_flags), materialize, discard)
    LLVM.define(jd, unit)
    return addr
end
122+
123+
import GPUCompiler: deferred_codegen_jobs
124+
"""
    deferred_codegen(f, Val(tt), Val(world))

Generated function. At generation time it builds a `CompilerJob` for the function type
`F`, argument tuple type `tt` and world age `world`, obtains a trampoline for that job via
`get_trampoline`, and records the job in `GPUCompiler.deferred_codegen_jobs` under the
trampoline address reinterpreted as an `Int`. The generated body calls the
`deferred_codegen` extern with the trampoline pointer and returns the resulting
`Ptr{Cvoid}`.
"""
@generated function deferred_codegen(f::F, ::Val{tt}, ::Val{world}) where {F,tt,world}
    # manual version of native_job because we have a function type
    source = methodinstance(F, Base.to_tuple_type(tt), world)
    # XXX: do we actually require the Julia runtime?
    #      with jlruntime=false, we reach an unreachable.
    target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true)
    config = CompilerConfig(target, TestCompilerParams(); kernel=false)
    job = CompilerJob(source, config, world)
    # XXX: invoking GPUCompiler from a generated function is not allowed!
    #      for things to work, we need to forward the correct world, at least.

    trampoline = pointer(get_trampoline(job))

    # register the job under the integer value of the trampoline pointer
    deferred_codegen_jobs[Base.reinterpret(Int, trampoline)] = job

    quote
        ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $trampoline)
        assume(ptr != C_NULL)
        return ptr
    end
end
148+
149+
"""
    abi_call(f::Ptr{Cvoid}, rt::Type, tt::Type, func, args...)

Generated function that emits a `ccall` through the function pointer `f`, building the
`ccall` signature from the return type wrapped in `rt` and the argument tuple type wrapped
in `tt`.

How each value is passed:
- `func` and each element of `args` is left out of the call when its type is a ghost type
  or a constant type;
- values whose type `GPUCompiler.deserves_argbox` are passed as `Any`;
- values whose LLVM type is an `LLVM.SequentialType` are passed as `Ptr{T}`, with the
  argument wrapped in `Base.Ref`;
- everything else is passed by value with its own type.

When the return type is not boxed and `GPUCompiler.deserves_sret` holds for it, a
`Ref{rettype}` is prepended as the first argument, the `ccall` return type becomes
`Nothing`, and the generated code returns the contents of that `Ref`.
"""
@generated function abi_call(f::Ptr{Cvoid}, rt::Type{RT}, tt::Type{T}, func::F, args::Vararg{Any, N}) where {T, RT, F, N}
    # inside the generator the arguments are bound to their types, so `tt` is `Type{T}`
    # and `rt` is `Type{RT}`; unwrap them to get the actual tuple and return types
    argtt = tt.parameters[1]
    rettype = rt.parameters[1]
    argtypes = DataType[argtt.parameters...]

    argexprs = Union{Expr, Symbol}[]
    ccall_types = DataType[]

    before = :()
    after = :(ret)

    # Note this follows: emit_call_specfun_other
    JuliaContext() do ctx
        # qualified like every other `isghosttype` use in this function, so that it does
        # not depend on which packages happen to export that name
        if !GPUCompiler.isghosttype(F) && !Core.Compiler.isconstType(F)
            isboxed = GPUCompiler.deserves_argbox(F)
            argexpr = :(func)
            if isboxed
                push!(ccall_types, Any)
            else
                # `func` is bound to the type `F` here
                et = convert(LLVMType, func)
                if isa(et, LLVM.SequentialType) # et->isAggregateType
                    push!(ccall_types, Ptr{F})
                    argexpr = Expr(:call, GlobalRef(Base, :Ref), argexpr)
                else
                    push!(ccall_types, F)
                end
            end
            push!(argexprs, argexpr)
        end

        T_jlvalue = LLVM.StructType(LLVMType[])
        T_prjlvalue = LLVM.PointerType(T_jlvalue, #= AddressSpace::Tracked =# 10)

        for (source_i, source_typ) in enumerate(argtypes)
            if GPUCompiler.isghosttype(source_typ) || Core.Compiler.isconstType(source_typ)
                continue
            end

            argexpr = :(args[$source_i])

            isboxed = GPUCompiler.deserves_argbox(source_typ)
            et = isboxed ? T_prjlvalue : convert(LLVMType, source_typ)

            if isboxed
                push!(ccall_types, Any)
            elseif isa(et, LLVM.SequentialType) # et->isAggregateType
                push!(ccall_types, Ptr{source_typ})
                argexpr = Expr(:call, GlobalRef(Base, :Ref), argexpr)
            else
                push!(ccall_types, source_typ)
            end
            push!(argexprs, argexpr)
        end

        if GPUCompiler.isghosttype(rettype) || Core.Compiler.isconstType(rettype)
            # Do nothing...
            # In theory we could set `rettype` to `T_void`, but ccall will do that for us
        # elseif jl_is_uniontype?
        elseif !GPUCompiler.deserves_retbox(rettype)
            # use a fresh local rather than rebinding the `rt` argument from inside this
            # closure
            llvm_rt = convert(LLVMType, rettype)
            if !isa(llvm_rt, LLVM.VoidType) && GPUCompiler.deserves_sret(rettype, llvm_rt)
                # struct return: the callee writes into a buffer passed as the first
                # argument, so the call itself returns nothing
                before = :(sret = Ref{$rettype}())
                pushfirst!(argexprs, :(sret))
                pushfirst!(ccall_types, Ptr{rettype})
                rettype = Nothing
                after = :(sret[])
            end
        else
            # rt = T_prjlvalue
        end
    end

    quote
        $before
        ret = ccall(f, $rettype, ($(ccall_types...),), $(argexprs...))
        $after
    end
end
227+
228+
"""
    call_delayed(f, args...)

Call `f(args...)` through the JIT: compute the argument tuple type and the inferred return
type, obtain a function pointer from `deferred_codegen` for the current world age, and
invoke it with `abi_call`.
"""
@inline function call_delayed(f::F, args...) where F
    argtt = Tuple{map(Core.Typeof, args)...}
    ret_t = Core.Compiler.return_type(f, argtt)
    age = GPUCompiler.tls_world_age()
    fptr = deferred_codegen(f, Val(argtt), Val(age))
    return abi_call(fptr, ret_t, argtt, f, args...)
end
235+
236+
# Create the target machine and the LLJIT instance that the rest of this example uses.
optlevel = LLVM.API.LLVMCodeGenLevelDefault
tm = GPUCompiler.JITTargetMachine(; optlevel)
LLVM.asm_verbosity!(tm, true)

lljit = LLJIT(; tm=tm)

jd_main = JITDylib(lljit)

# attach a search generator for the current process to the main dylib
# (presumably so that JIT'd code can resolve symbols from the running process)
prefix = LLVM.get_prefix(lljit)
dg = LLVM.CreateDynamicLibrarySearchGeneratorForProcess(prefix)
add!(jd_main, dg)
if Sys.iswindows() && Int === Int64
    # TODO can we check isGNU?
    define_absolute_symbol(jd_main, mangle(lljit, "___chkstk_ms"))
end

es = ExecutionSession(lljit)

lctm = LLVM.LocalLazyCallThroughManager(triple(lljit), es)
ism = LLVM.LocalIndirectStubsManager(triple(lljit))

# publish the instance for `get_trampoline`, and dispose of its parts at exit
# (stubs manager first, then the call-through manager, then the JIT)
jit[] = CompilerInstance(lljit, lctm, ism)
atexit() do
    instance = jit[]
    foreach(dispose, (instance.ism, instance.lctm, instance.jit))
end
264+
265+
266+
## demo
267+
268+
using Test
269+
270+
# smoke test
271+
# mutates its argument in place and returns nothing
function f(A)
    A[] += 42
    return nothing
end
global flag = [0]
caller() = call_delayed(f, flag::Vector{Int})
@test caller() === nothing
@test flag[] == 42

# test that we can call a function with a return value
function add(x, y)
    return x+y
end
call_add(x, y) = call_delayed(add, x, y)
@test call_add(1, 3) == 4

# increments the referenced value and returns the new value
function incr(r)
    return r[] += 1
end
call_incr(r) = call_delayed(incr, r)
r = Ref{Int}(0)
@test call_incr(r) == 1
@test r[] == 1

call_real(c) = call_delayed(real, c)
@test call_real(1.0+im) == 1.0

# tests struct return
if Sys.ARCH != :aarch64
    @test call_delayed(complex, 1.0, 2.0) == 1.0+2.0im
else
    @test_broken call_delayed(complex, 1.0, 2.0) == 1.0+2.0im
end

# indexing out of bounds must surface as a BoundsError through the JIT'd call
function throws(arr, i)
    return arr[i]
end
@test call_delayed(throws, [1], 1) == 1
@test_throws BoundsError call_delayed(throws, [1], 0)

# callable struct with a concretely typed field
struct Closure
    x::Int64
end
function (c::Closure)(b)
    return c.x+b
end
@test call_delayed(Closure(3), 5) == 8

# callable struct with an abstractly typed field
struct Closure2
    x::Integer
end
function (c::Closure2)(b)
    return c.x+b
end
@test call_delayed(Closure2(3), 5) == 8

src/metal.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,8 @@ end
171171
input = tempname(cleanup=false) * ".bc"
172172
translated = tempname(cleanup=false) * ".metallib"
173173
write(input, mod)
174-
Metal_LLVM_Tools_jll.metallib_as() do assembler
175-
proc = run(ignorestatus(`$assembler -o $translated $input`))
174+
let cmd = `$(Metal_LLVM_Tools_jll.metallib_as()) -o $translated $input`
175+
proc = run(ignorestatus(cmd))
176176
if !success(proc)
177177
error("""Failed to translate LLVM code to MetalLib.
178178
If you think this is a bug, please file an issue and attach $(input).""")
@@ -183,9 +183,7 @@ end
183183
read(translated)
184184
else
185185
# disassemble
186-
Metal_LLVM_Tools_jll.metallib_dis() do disassembler
187-
read(`$disassembler -o - $translated`, String)
188-
end
186+
read(`$(Metal_LLVM_Tools_jll.metallib_dis()) -o - $translated`, String)
189187
end
190188

191189
rm(input)

0 commit comments

Comments
 (0)