From db22eead17b78a759236b8d34e4194fd7a54c5c1 Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Wed, 8 Jan 2025 08:44:48 +0100
Subject: [PATCH] Clean-up.

---
 lib/cublas/wrappers.jl | 35 ++++++++++++++++++++---------------
 1 file changed, 20 insertions(+), 15 deletions(-)

diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl
index 4ec187f242..1e17b3f3fe 100644
--- a/lib/cublas/wrappers.jl
+++ b/lib/cublas/wrappers.jl
@@ -1213,24 +1213,29 @@ end
 
 # create a batch of pointers in device memory from a strided device array
 @inline function unsafe_strided_batch(strided::DenseCuArray{T}) where {T}
-    batchsize = last(size(strided))
-    stride = prod(size(strided)[1:end-1])
-
-    ptrs = CuArray{CuPtr{T}}(undef, batchsize)
-    nblocks = cld(batchsize, 256)
-    @cuda threads = 256 blocks = nblocks create_ptrs_kernel!(ptrs, strided, stride)
+    batch_size = last(size(strided))
+    batch_stride = prod(size(strided)[1:end-1])
+    #ptrs = [pointer(strided, (i-1)*batch_stride + 1) for i in 1:batch_size]
+    # create the array on the GPU to avoid synchronous copies and support larger batch sizes
+    ptrs = CuArray{CuPtr{T}}(undef, batch_size)
+    function compute_pointers()
+        i = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
+        grid_stride = gridDim().x * blockDim().x
+        while i <= length(ptrs)
+            @inbounds ptrs[i] =
+                reinterpret(CuPtr{T}, pointer(strided, (i - 1i32) * batch_stride + 1i32))
+            i += grid_stride
+        end
+        return
+    end
+    kernel = @cuda launch = false compute_pointers()
+    config = launch_configuration(kernel.fun)
+    threads = min(config.threads, batch_size)
+    blocks = min(config.blocks, cld(batch_size, threads))
+    @cuda threads blocks compute_pointers()
     return ptrs
 end
 
-function create_ptrs_kernel!(ptrs::CuDeviceArray{T}, A, batch_stride) where {T}
-    index = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
-    stride = gridDim().x * blockDim().x
-    for i in index:stride:length(ptrs)
-        ptrs[i] = reinterpret(CuPtr{T}, pointer(A, (i - 1i32) * batch_stride + 1i32))
-    end
-    return nothing
-end
-
 ## (GE) general matrix-matrix multiplication grouped batched
 for (fname, fname_64, elty) in ((:cublasSgemmGroupedBatched, :cublasSgemmGroupedBatched_64, :Float32),
                                 (:cublasDgemmGroupedBatched, :cublasDgemmGroupedBatched_64, :Float64))
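
The replacement code above follows CUDA.jl's usual occupancy-driven launch pattern: compile the
kernel with `@cuda launch=false`, query `launch_configuration` for a suggested geometry, clamp the
suggested threads/blocks to the problem size, and use a grid-stride loop so any launch geometry
still covers every element. A minimal sketch of the same pattern applied to a generic element-wise
kernel (the helper name `gpu_fill!` is hypothetical and used only for illustration, not part of
the patch):

    using CUDA

    function gpu_fill!(xs::CuArray{Float32}, val::Float32)
        function kernel()
            # grid-stride loop: each thread handles indices i, i + grid_stride, ...
            i = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
            grid_stride = gridDim().x * blockDim().x
            while i <= length(xs)
                @inbounds xs[i] = val
                i += grid_stride
            end
            return
        end
        # compile without launching, then ask for an occupancy-based configuration
        k = @cuda launch = false kernel()
        config = launch_configuration(k.fun)
        threads = min(config.threads, length(xs))
        blocks = min(config.blocks, cld(length(xs), threads))
        k(; threads, blocks)
        return xs
    end

    # usage: gpu_fill!(CUDA.zeros(Float32, 10_000), 1f0)

Compared to the old fixed `threads = 256` launch, this lets the driver pick a block size that
saturates the device while keeping the kernel correct for any batch size.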