Skip to content

Commit

Permalink
Clean-up.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Jan 8, 2025
1 parent 6d9b85d commit 464b2bf
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions lib/cublas/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1215,22 +1215,23 @@ end
# Build a vector of device pointers into `strided`, one per slice along its last dimension.
#
# Entry `i` of the result is the address of linear element `(i - 1) * stride + 1` of
# `strided`, where `stride` is the number of elements in one slice (the product of all
# dimensions except the last). The result has `last(size(strided))` entries of type
# `CuPtr{T}`.
#
# NOTE(review): the result holds raw addresses into `strided`; presumably the caller has to
# keep `strided` alive for as long as the pointers are in use -- confirm against callers.
@inline function unsafe_strided_batch(strided::DenseCuArray{T}) where {T}
    batchsize = last(size(strided))
    stride = prod(size(strided)[1:end-1])

    # create the array on the GPU to avoid synchronous copies and support larger batch sizes
    ptrs = CuArray{CuPtr{T}}(undef, batchsize)

    # an empty batch has no pointers to compute; skip the launch, since `cld(0, 256)` would
    # request a grid of zero blocks
    batchsize == 0 && return ptrs

    # Device-side helper that fills `ptrs`. `P` is the element type of `ptrs`.
    #
    # The locals of this nested function deliberately do not reuse the names of the
    # enclosing function's locals (`stride`, `batchsize`, ...): assigning to such a name
    # inside a nested function rebinds the outer variable, which turns the kernel into a
    # closure over a boxed variable instead of a plain function.
    function kernel!(ptrs::CuDeviceArray{P}, A, batch_stride) where {P}
        index = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
        nthreads = gridDim().x * blockDim().x
        # `i` never exceeds `length(ptrs)`, so the store below is always in bounds
        for i in index:nthreads:length(ptrs)
            @inbounds ptrs[i] =
                reinterpret(CuPtr{P}, pointer(A, (i - 1i32) * batch_stride + 1i32))
        end
        return
    end

    # a single launch fills the whole table: 256 threads per block, enough blocks to give
    # every entry its own thread
    nblocks = cld(batchsize, 256)
    @cuda threads = 256 blocks = nblocks kernel!(ptrs, strided, stride)
    return ptrs
end

# Kernel that fills `ptrs` with pointers into `A`.
#
# Entry `i` receives the address of linear element `(i - 1) * batch_stride + 1` of `A`,
# reinterpreted as `CuPtr{T}` where `T` is the element type of `ptrs`. Each thread starts
# at its global thread index and advances by the total number of launched threads until it
# moves past `length(ptrs)`.
function create_ptrs_kernel!(ptrs::CuDeviceArray{T}, A, batch_stride) where {T}
    first_slot = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
    total_threads = gridDim().x * blockDim().x
    for slot in first_slot:total_threads:length(ptrs)
        # 1-based linear index of the first element of batch entry `slot`
        offset = (slot - 1i32) * batch_stride + 1i32
        ptrs[slot] = reinterpret(CuPtr{T}, pointer(A, offset))
    end
    return nothing
end

## (GE) general matrix-matrix multiplication grouped batched
for (fname, fname_64, elty) in ((:cublasSgemmGroupedBatched, :cublasSgemmGroupedBatched_64, :Float32),
(:cublasDgemmGroupedBatched, :cublasDgemmGroupedBatched_64, :Float64))
Expand Down

0 comments on commit 464b2bf

Please sign in to comment.