diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl
index 4ec187f242..f0f6617318 100644
--- a/lib/cublas/wrappers.jl
+++ b/lib/cublas/wrappers.jl
@@ -1215,22 +1215,23 @@ end
 @inline function unsafe_strided_batch(strided::DenseCuArray{T}) where {T}
     batchsize = last(size(strided))
     stride = prod(size(strided)[1:end-1])
-
+    #ptrs = [pointer(strided, (i-1)*stride + 1) for i in 1:batchsize]
+    # create the array on the GPU to avoid synchronous copies and support larger batch sizes
     ptrs = CuArray{CuPtr{T}}(undef, batchsize)
+    function kernel!(ptrs::CuDeviceArray{T}, A, batch_stride) where {T}
+        index = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
+        stride = gridDim().x * blockDim().x
+        for i in index:stride:length(ptrs)
+            @inbounds ptrs[i] =
+                reinterpret(CuPtr{T}, pointer(A, (i - 1i32) * batch_stride + 1i32))
+        end
+        return
+    end
     nblocks = cld(batchsize, 256)
-    @cuda threads = 256 blocks = nblocks create_ptrs_kernel!(ptrs, strided, stride)
+    @cuda threads = 256 blocks = nblocks kernel!(ptrs, strided, stride)
     return ptrs
 end
 
-function create_ptrs_kernel!(ptrs::CuDeviceArray{T}, A, batch_stride) where {T}
-    index = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
-    stride = gridDim().x * blockDim().x
-    for i in index:stride:length(ptrs)
-        ptrs[i] = reinterpret(CuPtr{T}, pointer(A, (i - 1i32) * batch_stride + 1i32))
-    end
-    return nothing
-end
-
 ## (GE) general matrix-matrix multiplication grouped batched
 for (fname, fname_64, elty) in ((:cublasSgemmGroupedBatched, :cublasSgemmGroupedBatched_64, :Float32),
                                 (:cublasDgemmGroupedBatched, :cublasDgemmGroupedBatched_64, :Float64))