Skip to content

Commit

Permalink
Generalise address logic
Browse files Browse the repository at this point in the history
  • Loading branch information
THargreaves committed Jan 7, 2025
1 parent 18a4652 commit 6d9b85d
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions lib/cublas/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1215,20 +1215,18 @@ end
@inline function unsafe_strided_batch(strided::DenseCuArray{T}) where {T}
batchsize = last(size(strided))
stride = prod(size(strided)[1:end-1])
base_address = UInt64(pointer(strided))
offset = Base.elsize(strided) * stride

ptrs = CuArray{CuPtr{T}}(undef, batchsize)
nblocks = cld(batchsize, 1024)
@cuda threads = 1024 blocks = nblocks create_ptrs_kernel!(ptrs, base_address, offset)
nblocks = cld(batchsize, 256)
@cuda threads = 256 blocks = nblocks create_ptrs_kernel!(ptrs, strided, stride)
return ptrs
end

function create_ptrs_kernel!(ptrs::CuDeviceArray{T}, base_address, offset) where {T}
function create_ptrs_kernel!(ptrs::CuDeviceArray{T}, A, batch_stride) where {T}
index = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
stride = gridDim().x * blockDim().x
for i in index:stride:length(ptrs)
ptrs[i] = CuPtr{T}(base_address + (i - 1i32) * offset)
ptrs[i] = reinterpret(CuPtr{T}, pointer(A, (i - 1i32) * batch_stride + 1i32))
end
return nothing
end
Expand Down

0 comments on commit 6d9b85d

Please sign in to comment.