diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index ea60f2507..6256fa7bb 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -263,6 +263,15 @@ function Adapt.adapt_structure( ) end +function threads_to_workgroupsize(threads, ndrange) + total = 1 + return map(ndrange) do n + x = min(div(threads, total), n) + total *= x + return x + end +end + function (obj::KA.Kernel{ReactantBackend})(args...; ndrange=nothing, workgroupsize=nothing) backend = KA.backend(obj)