From 6d9b85dbb2876b3dae63a2bb03ada515d650afce Mon Sep 17 00:00:00 2001
From: Tim Hargreaves
Date: Tue, 7 Jan 2025 20:34:03 +0000
Subject: [PATCH] Generalise address logic

---
Notes (this section is ignored when the patch is applied):
  * unsafe_strided_batch no longer converts the base pointer to UInt64 and
    precomputes a byte offset; it passes the array and the per-batch element
    stride to the kernel instead.
  * create_ptrs_kernel! now obtains each batch pointer via
    pointer(A, (i - 1) * batch_stride + 1) and reinterprets it as CuPtr{T}.
  * The launch configuration changes from 1024 to 256 threads per block.

 lib/cublas/wrappers.jl | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl
index 0abfe6f283..4ec187f242 100644
--- a/lib/cublas/wrappers.jl
+++ b/lib/cublas/wrappers.jl
@@ -1215,20 +1215,18 @@ end
 @inline function unsafe_strided_batch(strided::DenseCuArray{T}) where {T}
     batchsize = last(size(strided))
     stride = prod(size(strided)[1:end-1])
-    base_address = UInt64(pointer(strided))
-    offset = Base.elsize(strided) * stride
 
     ptrs = CuArray{CuPtr{T}}(undef, batchsize)
-    nblocks = cld(batchsize, 1024)
-    @cuda threads = 1024 blocks = nblocks create_ptrs_kernel!(ptrs, base_address, offset)
+    nblocks = cld(batchsize, 256)
+    @cuda threads = 256 blocks = nblocks create_ptrs_kernel!(ptrs, strided, stride)
     return ptrs
 end
 
-function create_ptrs_kernel!(ptrs::CuDeviceArray{T}, base_address, offset) where {T}
+function create_ptrs_kernel!(ptrs::CuDeviceArray{T}, A, batch_stride) where {T}
     index = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
     stride = gridDim().x * blockDim().x
     for i in index:stride:length(ptrs)
-        ptrs[i] = CuPtr{T}(base_address + (i - 1i32) * offset)
+        ptrs[i] = reinterpret(CuPtr{T}, pointer(A, (i - 1i32) * batch_stride + 1i32))
     end
     return nothing
 end