Skip to content

Commit c2786dd

Browse files
KA without cuda backend (#670)
* KA without cuda backend * fix * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent e9471bd commit c2786dd

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

ext/ReactantCUDAExt.jl

+9-1
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ function threads_to_workgroupsize(threads, ndrange)
272272
end
273273
end
274274

275-
function (obj::KA.Kernel{ReactantBackend})(args...; ndrange=nothing, workgroupsize=nothing)
275+
function ka_with_reactant(ndrange, workgroupsize, obj, args...)
276276
backend = KA.backend(obj)
277277

278278
ndrange, workgroupsize, iterspace, dynamic = KA.launch_config(
@@ -325,6 +325,14 @@ function (obj::KA.Kernel{ReactantBackend})(args...; ndrange=nothing, workgroupsi
325325
return nothing
326326
end
327327

328+
Reactant.@reactant_overlay @noinline function (obj::KA.Kernel{ReactantBackend})(
329+
args...; ndrange=nothing, workgroupsize=nothing
330+
)
331+
return Reactant.call_with_reactant(
332+
ka_with_reactant, ndrange, workgroupsize, obj, args...
333+
)
334+
end
335+
328336
Adapt.adapt_storage(to::KA.ConstAdaptor, a::CuTracedArray) = Base.Experimental.Const(a)
329337

330338
function recudaconvert(arg)

ext/ReactantKernelAbstractionsExt.jl

+8
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,12 @@ function KA.priority!(::ReactantBackend, prio::Symbol)
8383
return nothing
8484
end
8585

86+
function tokw(ndrange, workgroupsize, obj, args...)
87+
@inline obj(args...; ndrange, workgroupsize)
88+
end
89+
90+
function (obj::KA.Kernel{ReactantBackend})(args...; ndrange=nothing, workgroupsize=nothing)
91+
@jit tokw(ndrange, workgroupsize, obj, args...)
92+
end
93+
8694
end

0 commit comments

Comments
 (0)