Skip to content

Commit 7782216

Browse files
committed
Add support for dynamically-constructed opaque closures.
1 parent c4d7db3 commit 7782216

File tree

2 files changed

+200
-70
lines changed

2 files changed

+200
-70
lines changed

src/compiler/compilation.jl

Lines changed: 161 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -435,22 +435,7 @@ end
435435
using Core.Compiler: IRCode
436436
using Core: CodeInfo, MethodInstance, CodeInstance, LineNumberNode
437437

438-
struct OpaqueClosure{F, E, A, R} # func, env, args, ret
439-
env::E
440-
end
441-
442-
# XXX: because we can't call functions from other CUDA modules, we effectively need to
443-
# recompile when the target function changes. this, and because of how GPUCompiler's
444-
# deferred compilation mechanism currently works, is why we have `F` as a type param.
445-
446-
# XXX: because of GPU code requiring specialized signatures, we also need to recompile
447-
# when the environment or argument types change. together with the above, this
448-
# negates much of the benefit of opaque closures.
449-
450-
# TODO: support for constructing an opaque closure from source code
451-
452-
# TODO: complete support for passing an environment. this probably requires a split into
453-
# host and device structures to, e.g., root a CuArray and pass a CuDeviceArray.
438+
# helpers
454439

455440
function compute_ir_rettype(ir::IRCode)
456441
rt = Union{}
@@ -463,32 +448,25 @@ function compute_ir_rettype(ir::IRCode)
463448
return Core.Compiler.widenconst(rt)
464449
end
465450

466-
function compute_oc_signature(ir::IRCode, nargs::Int, isva::Bool)
451+
function compute_oc_signature(ir::IRCode, nargs::Int)
467452
argtypes = Vector{Any}(undef, nargs)
468453
for i = 1:nargs
469454
argtypes[i] = Core.Compiler.widenconst(ir.argtypes[i+1])
470455
end
471-
if isva
472-
lastarg = pop!(argtypes)
473-
if lastarg <: Tuple
474-
append!(argtypes, lastarg.parameters)
475-
else
476-
push!(argtypes, Vararg{Any})
477-
end
478-
end
479456
return Tuple{argtypes...}
480457
end
481458

482-
function OpaqueClosure(ir::IRCode, @nospecialize env...;
483-
isva::Bool = false,
484-
slotnames::Union{Nothing,Vector{Symbol}}=nothing)
459+
function make_oc_codeinfo(ir::IRCode, @nospecialize env...; slotnames=nothing)
485460
# NOTE: we need ir.argtypes[1] == typeof(env)
486461
ir = Core.Compiler.copy(ir)
487-
# if the user didn't specify a definition MethodInstance or filename Symbol to use for the debuginfo, set a filename now
488-
ir.debuginfo.def === nothing && (ir.debuginfo.def = :var"generated IR for OpaqueClosure")
462+
# if the user didn't specify a definition MethodInstance or filename Symbol to use
463+
# for the debuginfo, set a filename now
464+
if ir.debuginfo.def === nothing
465+
ir.debuginfo.def = Symbol("IR for opaque gpu closure")
466+
end
489467
nargtypes = length(ir.argtypes)
490468
nargs = nargtypes-1
491-
sig = compute_oc_signature(ir, nargs, isva)
469+
sig = compute_oc_signature(ir, nargs)
492470
rt = compute_ir_rettype(ir)
493471
src = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ())
494472
if slotnames === nothing
@@ -499,61 +477,39 @@ function OpaqueClosure(ir::IRCode, @nospecialize env...;
499477
end
500478
src.slotflags = Base.fill(zero(UInt8), nargtypes)
501479
src.slottypes = copy(ir.argtypes)
502-
src = Core.Compiler.ir_to_codeinf!(src, ir)
503-
config = compiler_config(device(); kernel=false)
504-
return generate_opaque_closure(config, src, sig, rt, nargs, isva, env...)
505-
end
506-
507-
function OpaqueClosure(src::CodeInfo, @nospecialize env...; rettype, sig, nargs, isva=false)
508-
config = compiler_config(device(); kernel=false)
509-
return generate_opaque_closure(config, src, sig, rettype, nargs, isva, env...)
480+
Core.Compiler.ir_to_codeinf!(src, ir)
510481
end
511482

512-
function generate_opaque_closure(config::CompilerConfig, src::CodeInfo,
513-
@nospecialize(sig), @nospecialize(rt),
514-
nargs::Int, isva::Bool, @nospecialize env...;
515-
mod::Module=@__MODULE__,
516-
file::Union{Nothing,Symbol}=nothing, line::Int=0)
517-
# create a method (like `jl_make_opaque_closure_method`)
483+
# create a method (like `jl_make_oc_method`)
484+
function make_oc_method(nargs; file=nothing, line=0, world=GPUCompiler.tls_world_age())
518485
meth = ccall(:jl_new_method_uninit, Ref{Method}, (Any,), Main)
519486
meth.sig = Tuple
520-
meth.isva = isva # XXX: probably not supported?
521-
meth.is_for_opaque_closure = 0 # XXX: do we want this?
487+
meth.isva = false
488+
meth.is_for_opaque_closure = 0
522489
meth.name = Symbol("opaque gpu closure")
523490
meth.nargs = nargs + 1
524491
meth.file = something(file, Symbol())
525492
meth.line = line
526-
ccall(:jl_method_set_source, Nothing, (Any, Any), meth, src)
527-
528-
# look up a method instance and create a compiler job
529-
full_sig = Tuple{typeof(env), sig.parameters...}
530-
mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance},
531-
(Any, Any, Any), meth, full_sig, Core.svec())
532-
job = CompilerJob(mi, config) # this captures the current world age
533-
Base.@atomic meth.primary_world = job.world
493+
Base.@atomic meth.primary_world = world
534494
Base.@atomic meth.deleted_world = typemax(UInt)
495+
return meth
496+
end
535497

536-
# create a code instance and store it in the cache
537-
interp = GPUCompiler.get_interpreter(job)
498+
function make_oc_codeinstance(mi::MethodInstance, src::CodeInfo; interp, world, rt)
538499
owner = Core.Compiler.cache_owner(interp)
539500
exctype = Any
540501
inferred_const = C_NULL
541502
inferred = src
542503
const_flags = Int32(0)
543-
min_world = meth.primary_world
544-
max_world = meth.deleted_world
504+
min_world = world
505+
max_world = typemax(UInt)
545506
ipo_effects = UInt32(0)
546507
effects = UInt32(0)
547508
analysis_results = nothing
548509
relocatability = UInt8(0)
549-
ci = CodeInstance(mi, owner, rt, exctype, inferred_const, inferred,
550-
const_flags, min_world, max_world, ipo_effects, effects,
551-
analysis_results, relocatability, src.debuginfo)
552-
Core.Compiler.setindex!(GPUCompiler.ci_cache(job), ci, mi)
553-
554-
id = length(GPUCompiler.deferred_codegen_jobs) + 1
555-
GPUCompiler.deferred_codegen_jobs[id] = job
556-
return OpaqueClosure{id, typeof(env), sig, rt}(env)
510+
CodeInstance(mi, owner, rt, exctype, inferred_const, inferred,
511+
const_flags, min_world, max_world, ipo_effects, effects,
512+
analysis_results, relocatability, src.debuginfo)
557513
end
558514

559515
# generated function `ccall`, working around the restriction that ccall type
@@ -587,7 +543,60 @@ end
587543
return ex
588544
end
589545

590-
# device-side call to an opaque closure
546+
# static opaque closures
547+
548+
# XXX: because we can't call functions from other CUDA modules, we effectively need to
549+
# recompile when the target function changes. this, and because of how GPUCompiler's
550+
# deferred compilation mechanism currently works, is why we have `F` as a type param.
551+
552+
# XXX: because of GPU code requiring specialized signatures, we also need to recompile
553+
# when the environment or argument types change. together with the above, this
554+
# negates much of the benefit of opaque closures.
555+
556+
# TODO: support for constructing an opaque closure from source code
557+
558+
# TODO: complete support for passing an environment. this probably requires a split into
559+
# host and device structures to, e.g., root a CuArray and pass a CuDeviceArray.
560+
561+
struct OpaqueClosure{F, E, A, R} # func, env, args, ret
562+
env::E
563+
end
564+
565+
function OpaqueClosure(ir::IRCode, @nospecialize env...;
566+
slotnames::Union{Nothing,Vector{Symbol}}=nothing)
567+
nargtypes = length(ir.argtypes)
568+
nargs = nargtypes-1
569+
sig = compute_oc_signature(ir, nargs)
570+
rt = compute_ir_rettype(ir)
571+
src = make_oc_codeinfo(ir, env...; slotnames)
572+
return create_static_oc(src, sig, rt, nargs, env...)
573+
end
574+
575+
function OpaqueClosure(src::CodeInfo, @nospecialize env...; rettype, sig, nargs)
576+
return create_static_oc(src, sig, rettype, nargs, env...)
577+
end
578+
579+
function create_static_oc(src, @nospecialize(sig), @nospecialize(rt), nargs::Int,
580+
@nospecialize env...; file=nothing, line=0)
581+
config = compiler_config(device(); kernel=false)
582+
meth = make_oc_method(nargs; file, line)
583+
584+
# look up a method instance and create a compiler job
585+
full_sig = Tuple{typeof(env), sig.parameters...}
586+
mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance},
587+
(Any, Any, Any), meth, full_sig, Core.svec())
588+
job = CompilerJob(mi, config, meth.primary_world)
589+
590+
# create a callable object
591+
id = length(GPUCompiler.deferred_codegen_jobs) + 1
592+
GPUCompiler.deferred_codegen_jobs[id] = job
593+
oc = OpaqueClosure{id, typeof(env), sig, rt}(env)
594+
595+
opaque_closure_jobs[job] = (; oc, src, rt)
596+
return oc
597+
end
598+
599+
# device-side call
591600
(oc::OpaqueClosure)(args...) = call(oc, args...)
592601
## NOTE: split into two to make `SciML.isinplace(oc)` work.
593602
## it also resembles how kernels are called.
@@ -597,3 +606,87 @@ end
597606
#ccall(ptr, R, (A...), args...)
598607
generated_ccall(ptr, R, A, args...)
599608
end
609+
610+
# dynamic opaque closures
611+
612+
const jit_opaque_closures = Dict()
613+
614+
struct JITOpaqueClosure{B, T}
615+
builder::B
616+
tfunc::T
617+
618+
function JITOpaqueClosure(builder, tfunc=Returns(nothing); nargs)
619+
# the device and world are captured at closure construction time, but we only need
620+
# them when creating the CompilerJob. as we cannot simply encode them in the
621+
# JITOpaqueClosure object, we store them in a global dictionary instead.
622+
config = compiler_config(device(); kernel=false)
623+
meth = make_oc_method(nargs)
624+
625+
# create a callable object
626+
oc = new{typeof(builder), typeof(tfunc)}(builder, tfunc)
627+
jit_opaque_closures[typeof(oc)] = (; env=(), meth, config, oc)
628+
629+
return oc
630+
end
631+
end
632+
633+
# device-side call
634+
function (oc::JITOpaqueClosure)(args...)
635+
rt = oc.tfunc(map(Core.Typeof, args)...)
636+
call(oc, rt, args...)
637+
end
638+
@inline @generated function call(oct::JITOpaqueClosure{B,T}, ::Type{R}, args...) where {B,T,R}
639+
rt = R
640+
(; env, meth, config, oc) = jit_opaque_closures[oct]
641+
642+
# look up a method instance and create a compiler job
643+
full_sig = Tuple{typeof(env), args...}
644+
mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance},
645+
(Any, Any, Any), meth, full_sig, Core.svec())
646+
job = CompilerJob(mi, config, meth.primary_world)
647+
opaque_closure_jobs[job] = (; oc, args, rt)
648+
649+
# generate a deferred compilation call
650+
id = length(GPUCompiler.deferred_codegen_jobs) + 1
651+
GPUCompiler.deferred_codegen_jobs[id] = job
652+
quote
653+
ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), $id)
654+
assume(ptr != C_NULL)
655+
#ccall(ptr, R, (A...), args...)
656+
generated_ccall(ptr, $rt, $(Tuple{args...}), args...)
657+
end
658+
end
659+
660+
# compilation of opaque closures
661+
662+
const opaque_closure_jobs = Dict{CompilerJob,Any}()
663+
664+
function GPUCompiler.prepare_job!(@nospecialize(job::CUDACompilerJob))
665+
if haskey(opaque_closure_jobs, job)
666+
rt = opaque_closure_jobs[job].rt
667+
oc = opaque_closure_jobs[job].oc
668+
if oc isa JITOpaqueClosure
669+
args = opaque_closure_jobs[job].args
670+
nargs = length(args)
671+
672+
src = oc.builder(args...)
673+
if src isa IRCode
674+
nargtypes = length(src.argtypes)
675+
nargs = nargtypes-1
676+
sig = compute_oc_signature(src, nargs)
677+
@assert compute_ir_rettype(src) == rt "Inferred return type does not match the provided return type"
678+
src = make_oc_codeinfo(src)
679+
end
680+
else
681+
src = opaque_closure_jobs[job].src
682+
end
683+
@assert src isa CodeInfo
684+
685+
# create a code instance and store it in the cache
686+
interp = GPUCompiler.get_interpreter(job)
687+
ci = make_oc_codeinstance(job.source, src; interp, job.world, rt)
688+
Core.Compiler.setindex!(GPUCompiler.ci_cache(job), ci, job.source)
689+
end
690+
691+
return
692+
end

test/core/execution.jl

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,7 +1099,7 @@ end
10991099
if VERSION >= v"1.12-"
11001100
@testset "opaque closures" begin
11011101

1102-
# basic closure, constructed from IRCode
1102+
# static closure, constructed from IRCode
11031103
let
11041104
ir, rettyp = only(Base.code_ircode(+, (Int, Int)))
11051105
oc = CUDA.OpaqueClosure(ir)
@@ -1118,7 +1118,7 @@ let
11181118
@test Array(c)[] == 3
11191119
end
11201120

1121-
# basic closure, constructed from CodeInfo
1121+
# static closure, constructed from CodeInfo
11221122
let
11231123
ir, rettype = only(Base.code_typed(*, (Int, Int, Int)))
11241124
oc = CUDA.OpaqueClosure(ir; sig=Tuple{Int,Int,Int}, rettype, nargs=3)
@@ -1138,6 +1138,43 @@ let
11381138
@test Array(d)[] == 24
11391139
end
11401140

1141+
# dynamic closure, constructing IRCode based on argument types
1142+
let
1143+
tfunc(arg1, arg2) = Core.Compiler.return_type(+, Tuple{arg1,arg2})
1144+
function builder(arg1, arg2)
1145+
ir, rettyp = only(Base.code_ircode(+, (arg1, arg2)))
1146+
return ir
1147+
end
1148+
1149+
oc = CUDA.JITOpaqueClosure(builder, tfunc; nargs=2)
1150+
1151+
function kernel(oc, c, a, b)
1152+
i = threadIdx().x
1153+
@inbounds c[i] = oc(a[i], b[i])
1154+
return
1155+
end
1156+
1157+
let
1158+
c = CuArray([0])
1159+
a = CuArray([1])
1160+
b = CuArray([2])
1161+
1162+
@cuda threads=1 kernel(oc, c, a, b)
1163+
1164+
@test Array(c)[] == 3
1165+
end
1166+
1167+
let
1168+
c = CuArray([3f0])
1169+
a = CuArray([4f0])
1170+
b = CuArray([5f0])
1171+
1172+
@cuda threads=1 kernel(oc, c, a, b)
1173+
1174+
@test Array(c)[] == 9f0
1175+
end
1176+
end
1177+
11411178
end
11421179
end
11431180

0 commit comments

Comments
 (0)