Skip to content

Commit

Permalink
Clean-up.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Jan 8, 2025
1 parent 6d9b85d commit db22eea
Showing 1 changed file with 20 additions and 15 deletions.
35 changes: 20 additions & 15 deletions lib/cublas/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1213,24 +1213,29 @@ end

# create a batch of pointers in device memory from a strided device array
@inline function unsafe_strided_batch(strided::DenseCuArray{T}) where {T}
batchsize = last(size(strided))
stride = prod(size(strided)[1:end-1])

ptrs = CuArray{CuPtr{T}}(undef, batchsize)
nblocks = cld(batchsize, 256)
@cuda threads = 256 blocks = nblocks create_ptrs_kernel!(ptrs, strided, stride)
batch_size = last(size(strided))
batch_stride = prod(size(strided)[1:end-1])
#ptrs = [pointer(strided, (i-1)*batch_stride + 1) for i in 1:batch_size]
# create the array on the GPU to avoid synchronous copies and support larger batch sizes
ptrs = CuArray{CuPtr{T}}(undef, batch_size)
function compute_pointers()
i = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
grid_stride = gridDim().x * blockDim().x
while i <= length(ptrs)
@inbounds ptrs[i] =
reinterpret(CuPtr{T}, pointer(strided, (i - 1i32) * batch_stride + 1i32))
i += grid_stride
end
return
end
kernel = @cuda launch = false compute_pointers()
config = launch_configuration(kernel.fun)
threads = min(config.threads, batch_size)
blocks = min(config.blocks, cld(batch_size, threads))
@cuda threads blocks compute_pointers()
return ptrs
end

function create_ptrs_kernel!(ptrs::CuDeviceArray{T}, A, batch_stride) where {T}
index = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
stride = gridDim().x * blockDim().x
for i in index:stride:length(ptrs)
ptrs[i] = reinterpret(CuPtr{T}, pointer(A, (i - 1i32) * batch_stride + 1i32))
end
return nothing
end

## (GE) general matrix-matrix multiplication grouped batched
for (fname, fname_64, elty) in ((:cublasSgemmGroupedBatched, :cublasSgemmGroupedBatched_64, :Float32),
(:cublasDgemmGroupedBatched, :cublasDgemmGroupedBatched_64, :Float64))
Expand Down

0 comments on commit db22eea

Please sign in to comment.