435
435
using Core. Compiler: IRCode
436
436
using Core: CodeInfo, MethodInstance, CodeInstance, LineNumberNode
437
437
438
- struct OpaqueClosure{F, E, A, R} # func, env, args, ret
439
- env:: E
440
- end
441
-
442
- # XXX : because we can't call functions from other CUDA modules, we effectively need to
443
- # recompile when the target function changes. this, and because of how GPUCompiler's
444
- # deferred compilation mechanism currently works, is why we have `F` as a type param.
445
-
446
- # XXX : because of GPU code requiring specialized signatures, we also need to recompile
447
- # when the environment or argument types change. together with the above, this
448
- # negates much of the benefit of opaque closures.
449
-
450
- # TODO : support for constructing an opaque closure from source code
451
-
452
- # TODO : complete support for passing an environment. this probably requires a split into
453
- # host and device structures to, e.g., root a CuArray and pass a CuDeviceArray.
438
+ # helpers
454
439
455
440
function compute_ir_rettype (ir:: IRCode )
456
441
rt = Union{}
@@ -463,32 +448,25 @@ function compute_ir_rettype(ir::IRCode)
463
448
return Core. Compiler. widenconst (rt)
464
449
end
465
450
466
- function compute_oc_signature (ir:: IRCode , nargs:: Int , isva :: Bool )
451
+ function compute_oc_signature (ir:: IRCode , nargs:: Int )
467
452
argtypes = Vector {Any} (undef, nargs)
468
453
for i = 1 : nargs
469
454
argtypes[i] = Core. Compiler. widenconst (ir. argtypes[i+ 1 ])
470
455
end
471
- if isva
472
- lastarg = pop! (argtypes)
473
- if lastarg <: Tuple
474
- append! (argtypes, lastarg. parameters)
475
- else
476
- push! (argtypes, Vararg{Any})
477
- end
478
- end
479
456
return Tuple{argtypes... }
480
457
end
481
458
482
- function OpaqueClosure (ir:: IRCode , @nospecialize env... ;
483
- isva:: Bool = false ,
484
- slotnames:: Union{Nothing,Vector{Symbol}} = nothing )
459
+ function make_oc_codeinfo (ir:: IRCode , @nospecialize env... ; slotnames= nothing )
485
460
# NOTE: we need ir.argtypes[1] == typeof(env)
486
461
ir = Core. Compiler. copy (ir)
487
- # if the user didn't specify a definition MethodInstance or filename Symbol to use for the debuginfo, set a filename now
488
- ir. debuginfo. def === nothing && (ir. debuginfo. def = :var"generated IR for OpaqueClosure" )
462
+ # if the user didn't specify a definition MethodInstance or filename Symbol to use
463
+ # for the debuginfo, set a filename now
464
+ if ir. debuginfo. def === nothing
465
+ ir. debuginfo. def = Symbol (" IR for opaque gpu closure" )
466
+ end
489
467
nargtypes = length (ir. argtypes)
490
468
nargs = nargtypes- 1
491
- sig = compute_oc_signature (ir, nargs, isva )
469
+ sig = compute_oc_signature (ir, nargs)
492
470
rt = compute_ir_rettype (ir)
493
471
src = ccall (:jl_new_code_info_uninit , Ref{CodeInfo}, ())
494
472
if slotnames === nothing
@@ -499,61 +477,39 @@ function OpaqueClosure(ir::IRCode, @nospecialize env...;
499
477
end
500
478
src. slotflags = Base. fill (zero (UInt8), nargtypes)
501
479
src. slottypes = copy (ir. argtypes)
502
- src = Core. Compiler. ir_to_codeinf! (src, ir)
503
- config = compiler_config (device (); kernel= false )
504
- return generate_opaque_closure (config, src, sig, rt, nargs, isva, env... )
505
- end
506
-
507
- function OpaqueClosure (src:: CodeInfo , @nospecialize env... ; rettype, sig, nargs, isva= false )
508
- config = compiler_config (device (); kernel= false )
509
- return generate_opaque_closure (config, src, sig, rettype, nargs, isva, env... )
480
+ Core. Compiler. ir_to_codeinf! (src, ir)
510
481
end
511
482
512
- function generate_opaque_closure (config:: CompilerConfig , src:: CodeInfo ,
513
- @nospecialize (sig), @nospecialize (rt),
514
- nargs:: Int , isva:: Bool , @nospecialize env... ;
515
- mod:: Module = @__MODULE__ ,
516
- file:: Union{Nothing,Symbol} = nothing , line:: Int = 0 )
517
- # create a method (like `jl_make_opaque_closure_method`)
483
+ # create a method (like `jl_make_oc_method`)
484
+ function make_oc_method (nargs; file= nothing , line= 0 , world= GPUCompiler. tls_world_age ())
518
485
meth = ccall (:jl_new_method_uninit , Ref{Method}, (Any,), Main)
519
486
meth. sig = Tuple
520
- meth. isva = isva # XXX : probably not supported?
521
- meth. is_for_opaque_closure = 0 # XXX : do we want this?
487
+ meth. isva = false
488
+ meth. is_for_opaque_closure = 0
522
489
meth. name = Symbol (" opaque gpu closure" )
523
490
meth. nargs = nargs + 1
524
491
meth. file = something (file, Symbol ())
525
492
meth. line = line
526
- ccall (:jl_method_set_source , Nothing, (Any, Any), meth, src)
527
-
528
- # look up a method instance and create a compiler job
529
- full_sig = Tuple{typeof (env), sig. parameters... }
530
- mi = ccall (:jl_specializations_get_linfo , Ref{MethodInstance},
531
- (Any, Any, Any), meth, full_sig, Core. svec ())
532
- job = CompilerJob (mi, config) # this captures the current world age
533
- Base. @atomic meth. primary_world = job. world
493
+ Base. @atomic meth. primary_world = world
534
494
Base. @atomic meth. deleted_world = typemax (UInt)
495
+ return meth
496
+ end
535
497
536
- # create a code instance and store it in the cache
537
- interp = GPUCompiler. get_interpreter (job)
498
+ function make_oc_codeinstance (mi:: MethodInstance , src:: CodeInfo ; interp, world, rt)
538
499
owner = Core. Compiler. cache_owner (interp)
539
500
exctype = Any
540
501
inferred_const = C_NULL
541
502
inferred = src
542
503
const_flags = Int32 (0 )
543
- min_world = meth . primary_world
544
- max_world = meth . deleted_world
504
+ min_world = world
505
+ max_world = typemax (UInt)
545
506
ipo_effects = UInt32 (0 )
546
507
effects = UInt32 (0 )
547
508
analysis_results = nothing
548
509
relocatability = UInt8 (0 )
549
- ci = CodeInstance (mi, owner, rt, exctype, inferred_const, inferred,
550
- const_flags, min_world, max_world, ipo_effects, effects,
551
- analysis_results, relocatability, src. debuginfo)
552
- Core. Compiler. setindex! (GPUCompiler. ci_cache (job), ci, mi)
553
-
554
- id = length (GPUCompiler. deferred_codegen_jobs) + 1
555
- GPUCompiler. deferred_codegen_jobs[id] = job
556
- return OpaqueClosure {id, typeof(env), sig, rt} (env)
510
+ CodeInstance (mi, owner, rt, exctype, inferred_const, inferred,
511
+ const_flags, min_world, max_world, ipo_effects, effects,
512
+ analysis_results, relocatability, src. debuginfo)
557
513
end
558
514
559
515
# generated function `ccall`, working around the restriction that ccall type
587
543
return ex
588
544
end
589
545
590
- # device-side call to an opaque closure
546
+ # static opaque closures
547
+
548
+ # XXX : because we can't call functions from other CUDA modules, we effectively need to
549
+ # recompile when the target function changes. this, and because of how GPUCompiler's
550
+ # deferred compilation mechanism currently works, is why we have `F` as a type param.
551
+
552
+ # XXX : because of GPU code requiring specialized signatures, we also need to recompile
553
+ # when the environment or argument types change. together with the above, this
554
+ # negates much of the benefit of opaque closures.
555
+
556
+ # TODO : support for constructing an opaque closure from source code
557
+
558
+ # TODO : complete support for passing an environment. this probably requires a split into
559
+ # host and device structures to, e.g., root a CuArray and pass a CuDeviceArray.
560
+
561
# Device-callable opaque closure object.
#
# Type parameters, in order:
#   F - the target function
#   E - type of the captured environment
#   A - argument types
#   R - return type
#
# The only runtime payload is the captured environment itself.
struct OpaqueClosure{F, E, A, R}
    env::E
end
564
+
565
# Construct a static opaque closure from typed IR.
#
# The signature and return type are derived from `ir`; the IR is then lowered to a
# `CodeInfo` (optionally using the caller-provided `slotnames`) and handed to
# `create_static_oc` together with the environment values `env`.
function OpaqueClosure(ir::IRCode, @nospecialize env...;
                       slotnames::Union{Nothing,Vector{Symbol}}=nothing)
    # the first entry of `ir.argtypes` describes the environment, not an argument
    nargs = length(ir.argtypes) - 1
    sig = compute_oc_signature(ir, nargs)
    rt = compute_ir_rettype(ir)
    src = make_oc_codeinfo(ir, env...; slotnames=slotnames)
    return create_static_oc(src, sig, rt, nargs, env...)
end
574
+
575
# Construct a static opaque closure from an already-lowered `CodeInfo`. Nothing is derived
# from `src` here, so the caller must supply `rettype`, `sig` and `nargs` explicitly.
OpaqueClosure(src::CodeInfo, @nospecialize env...; rettype, sig, nargs) =
    create_static_oc(src, sig, rettype, nargs, env...)
578
+
579
# Create a static `OpaqueClosure` for `src`.
#
# A fresh method is created and specialized on `(typeof(env), sig.parameters...)`, a
# non-kernel compiler job is registered for deferred code generation, and the job is
# recorded in `opaque_closure_jobs` together with the source and return type.
# Returns the callable closure object.
function create_static_oc(src, @nospecialize(sig), @nospecialize(rt), nargs::Int,
                          @nospecialize env...; file=nothing, line=0)
    config = compiler_config(device(); kernel=false)
    meth = make_oc_method(nargs; file=file, line=line)

    # look up a method instance and create a compiler job
    env_type = typeof(env)
    full_sig = Tuple{env_type, sig.parameters...}
    mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance},
               (Any, Any, Any), meth, full_sig, Core.svec())
    job = CompilerJob(mi, config, meth.primary_world)

    # create a callable object; the deferred-codegen job id becomes the `F` type parameter
    deferred = GPUCompiler.deferred_codegen_jobs
    id = length(deferred) + 1
    deferred[id] = job
    oc = OpaqueClosure{id, env_type, sig, rt}(env)

    opaque_closure_jobs[job] = (; oc=oc, src=src, rt=rt)
    return oc
end
598
+
599
+ # device-side call
591
600
(oc:: OpaqueClosure )(args... ) = call (oc, args... )
592
601
# # NOTE: split into two to make `SciML.isinplace(oc)` work.
593
602
# # it also resembles how kernels are called.
597
606
# ccall(ptr, R, (A...), args...)
598
607
generated_ccall (ptr, R, A, args... )
599
608
end
609
+
610
+ # dynamic opaque closures
611
+
612
# registry of JIT opaque closures, keyed by the closure's concrete type
const jit_opaque_closures = Dict{Any,Any}()
613
+
614
# Opaque closure whose source is generated on demand.
#
# Fields:
#   builder - callable that produces the closure's source
#   tfunc   - callable invoked with the argument types to obtain the return type
#             (defaults to `Returns(nothing)`)
struct JITOpaqueClosure{B, T}
    builder::B
    tfunc::T

    function JITOpaqueClosure(builder, tfunc=Returns(nothing); nargs)
        # The compiler config and method are set up at construction time but are only
        # needed once a CompilerJob is created. They cannot simply be encoded in the
        # JITOpaqueClosure object, so they are kept in a global dictionary keyed by the
        # closure's type.
        config = compiler_config(device(); kernel=false)
        meth = make_oc_method(nargs)

        # create a callable object and register its compilation state
        closure = new{typeof(builder), typeof(tfunc)}(builder, tfunc)
        jit_opaque_closures[typeof(closure)] =
            (; env=(), meth=meth, config=config, oc=closure)

        return closure
    end
end
632
+
633
+ # device-side call
634
# Invoke the closure: ask `tfunc` for the return type given the argument types, then
# forward to `call` with that type.
function (oc::JITOpaqueClosure)(args...)
    argtypes = map(Core.Typeof, args)
    return call(oc, oc.tfunc(argtypes...), args...)
end
638
# Generated call path for JIT opaque closures.
#
# At generation time (where `oct` and `args` are types) this looks up the state registered
# for the closure type, creates and registers a compiler job for this argument signature,
# and emits code that obtains the function pointer through deferred codegen and calls it.
@inline @generated function call(oct::JITOpaqueClosure{B,T}, ::Type{R}, args...) where {B,T,R}
    entry = jit_opaque_closures[oct]
    meth = entry.meth

    # look up a method instance and create a compiler job
    full_sig = Tuple{typeof(entry.env), args...}
    mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance},
               (Any, Any, Any), meth, full_sig, Core.svec())
    job = CompilerJob(mi, entry.config, meth.primary_world)
    opaque_closure_jobs[job] = (; oc=entry.oc, args=args, rt=R)

    # generate a deferred compilation call
    deferred = GPUCompiler.deferred_codegen_jobs
    id = length(deferred) + 1
    deferred[id] = job
    argtuple = Tuple{args...}
    return quote
        ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), $id)
        assume(ptr != C_NULL)
        # ccall(ptr, R, (A...), args...)
        generated_ccall(ptr, $R, $argtuple, args...)
    end
end
659
+
660
+ # compilation of opaque closures
661
+
662
+ const opaque_closure_jobs = Dict {CompilerJob,Any} ()
663
+
664
# Hook run before a job is compiled: for jobs that belong to an opaque closure, build
# the closure's `CodeInstance` and store it in the job's code-instance cache.
#
# Jobs not registered in `opaque_closure_jobs` are left untouched. For a
# `JITOpaqueClosure` the source is produced now by calling its `builder` with the recorded
# argument types; if that yields `IRCode`, its inferred return type must equal the
# recorded one and it is lowered to a `CodeInfo`. For static closures the stored source
# is used as-is. Always returns `nothing`.
function GPUCompiler.prepare_job!(@nospecialize(job::CUDACompilerJob))
    # single lookup instead of `haskey` followed by repeated indexing
    entry = get(opaque_closure_jobs, job, nothing)
    entry === nothing && return

    rt = entry.rt
    oc = entry.oc
    src = if oc isa JITOpaqueClosure
        built = oc.builder(entry.args...)
        if built isa IRCode
            @assert compute_ir_rettype(built) == rt "Inferred return type does not match the provided return type"
            make_oc_codeinfo(built)
        else
            built
        end
    else
        entry.src
    end
    @assert src isa CodeInfo

    # create a code instance and store it in the cache
    interp = GPUCompiler.get_interpreter(job)
    ci = make_oc_codeinstance(job.source, src; interp=interp, world=job.world, rt=rt)
    Core.Compiler.setindex!(GPUCompiler.ci_cache(job), ci, job.source)

    return
end
0 commit comments