Skip to content

Commit

Permalink
Add wrapper for gemmBatchedEx! (#1975)
Browse files Browse the repository at this point in the history
Co-authored-by: Tim Besard <[email protected]>
  • Loading branch information
lpawela and maleadt authored Jun 29, 2023
1 parent 315c80e commit abd569e
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 3 deletions.
88 changes: 87 additions & 1 deletion lib/cublas/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -887,7 +887,7 @@ function gemmEx!(transA::Char, transB::Char,
k = size(A, transA == 'N' ? 2 : 1)
n = size(B, transB == 'N' ? 2 : 1)
if m != size(C,1) || n != size(C,2) || k != size(B, transB == 'N' ? 1 : 2)
throw(DimensionMismatch(""))
throw(DimensionMismatch("A has dimension $(size(A)), B has dimension $(size(B)) and C has dimension $(size(C))"))
end
lda = max(1,stride(A,2))
ldb = max(1,stride(B,2))
Expand All @@ -909,6 +909,91 @@ function gemmEx!(transA::Char, transB::Char,
C
end

"""
    gemmBatchedEx!(transA, transB, alpha, A, B, beta, C; algo=CUBLAS_GEMM_DEFAULT)

Batched matrix-matrix multiply with mixed-precision support, computing
`C[i] = alpha * op(A[i]) * op(B[i]) + beta * C[i]` for every batch entry `i`,
where `op` is the identity, transpose or adjoint as selected by `transA`/`transB`.

# Arguments
- `transA`, `transB`: `'N'`, `'T'` or `'C'`, applied to every `A[i]` resp. `B[i]`.
- `alpha`, `beta`: scalar multipliers, converted to the storage type of the
  compute type selected for the element-type combination.
- `A`, `B`, `C`: equal-length vectors of device matrices/vectors; `C` is
  overwritten in place and returned.
- `algo`: cuBLAS GEMM algorithm selector.

# Throws
- `DimensionMismatch` if the batch lengths differ or any entry's shapes are
  incompatible.
- `ArgumentError` for an empty batch or an unsupported element-type combination.
"""
function gemmBatchedEx!(transA::Char, transB::Char,
                        @nospecialize(alpha::Number),
                        @nospecialize(A::Vector{<:StridedCuVecOrMat}),
                        @nospecialize(B::Vector{<:StridedCuVecOrMat}),
                        @nospecialize(beta::Number),
                        @nospecialize(C::Vector{<:StridedCuVecOrMat});
                        algo::cublasGemmAlgo_t=CUBLAS_GEMM_DEFAULT)
    if length(A) != length(B) || length(A) != length(C)
        throw(DimensionMismatch("Lengths of inputs must be the same"))
    end
    # give a clear error instead of a BoundsError from A[1] below
    isempty(A) && throw(ArgumentError("Batch must contain at least one matrix"))
    # validate each batch entry against the (possibly transposed) problem shape
    for (i, (As,Bs,Cs)) in enumerate(zip(A,B,C))
        m = size(As, transA == 'N' ? 1 : 2)
        k = size(As, transA == 'N' ? 2 : 1)
        n = size(Bs, transB == 'N' ? 2 : 1)
        if m != size(Cs,1) || n != size(Cs,2) || k != size(Bs, transB == 'N' ? 1 : 2)
            throw(DimensionMismatch("Input $i: A has dimension $(size(As)), B has dimension $(size(Bs)), C has dimension $(size(Cs))"))
        end
    end
    # cuBLAS requires uniform dimensions/strides across the batch; take them
    # from the first entry.
    m = size(A[1], transA == 'N' ? 1 : 2)
    k = size(A[1], transA == 'N' ? 2 : 1)
    n = size(B[1], transB == 'N' ? 2 : 1)
    lda = max(1,stride(A[1],2))
    ldb = max(1,stride(B[1],2))
    ldc = max(1,stride(C[1],2))
    computeType = gemmExComputeType(eltype(A[1]), eltype(B[1]), eltype(C[1]), m, k, n)
    isnothing(computeType) &&
        throw(ArgumentError("gemmBatchedEx! does not support $(eltype(C[1]))=$(eltype(A[1]))*$(eltype(B[1]))"))
    computeT = juliaStorageType(eltype(C[1]), computeType)
    # device-side arrays of pointers to the batch entries
    Aptrs = unsafe_batch(A)
    Bptrs = unsafe_batch(B)
    Cptrs = unsafe_batch(C)
    try
        if version() >= v"11.0"
            # with CUDA 11, the compute type encodes the math mode.
            cublasGemmBatchedEx(handle(), transA, transB, m, n, k, Ref{computeT}(alpha),
                                Aptrs, eltype(A[1]), lda, Bptrs, eltype(B[1]), ldb,
                                Ref{computeT}(beta), Cptrs, eltype(C[1]), ldc,
                                length(A), computeType, algo)
        else
            error("gemmBatchedEx! is not implemented for CUDA versions below 11.")
        end
    finally
        # free the pointer batches even if the cuBLAS call throws
        unsafe_free!(Cptrs)
        unsafe_free!(Bptrs)
        unsafe_free!(Aptrs)
    end

    C
end

"""
    gemmStridedBatchedEx!(transA, transB, alpha, A, B, beta, C; algo=CUBLAS_GEMM_DEFAULT)

Strided batched matrix-matrix multiply with mixed-precision support, computing
`C[:,:,i] = alpha * op(A[:,:,i]) * op(B[:,:,i]) + beta * C[:,:,i]` for every
batch slice `i`, where `op` is selected by `transA`/`transB`.

A singleton batch dimension in `A` or `B` (`size(·, 3) == 1`) is broadcast
across the batch by passing a zero stride to cuBLAS.

# Arguments
- `transA`, `transB`: `'N'`, `'T'` or `'C'`, applied to every slice.
- `alpha`, `beta`: scalar multipliers.
- `A`, `B`, `C`: rank-3 device arrays; `C` is overwritten in place and returned.
- `algo`: cuBLAS GEMM algorithm selector.

# Throws
- `DimensionMismatch` if batch sizes or slice shapes are incompatible.
- `ArgumentError` for an unsupported element-type combination.
"""
function gemmStridedBatchedEx!(transA::Char, transB::Char,
                               @nospecialize(alpha::Number),
                               @nospecialize(A::AbstractArray{Ta, 3}),
                               @nospecialize(B::AbstractArray{Tb, 3}),
                               @nospecialize(beta::Number),
                               @nospecialize(C::AbstractArray{Tc, 3});
                               algo::cublasGemmAlgo_t=CUBLAS_GEMM_DEFAULT) where {Ta, Tb, Tc}
    if size(A, 3) != size(B, 3) || size(A, 3) != size(C, 3)
        throw(DimensionMismatch("Batch sizes must be equal for all inputs"))
    end
    m = size(A, transA == 'N' ? 1 : 2)
    k = size(A, transA == 'N' ? 2 : 1)
    n = size(B, transB == 'N' ? 2 : 1)
    if m != size(C,1) || n != size(C,2) || k != size(B, transB == 'N' ? 1 : 2)
        throw(DimensionMismatch("A has dimension $(size(A)), B has dimension $(size(B)), C has dimension $(size(C))"))
    end
    lda = max(1,stride(A,2))
    ldb = max(1,stride(B,2))
    ldc = max(1,stride(C,2))

    # a zero stride makes cuBLAS reuse the single slice for every batch entry
    strideA = size(A, 3) == 1 ? 0 : stride(A, 3)
    strideB = size(B, 3) == 1 ? 0 : stride(B, 3)
    strideC = stride(C, 3)
    batchCount = size(C, 3)

    computeType = gemmExComputeType(eltype(A), eltype(B), eltype(C), m, k, n)
    isnothing(computeType) &&
        throw(ArgumentError("gemmStridedBatchedEx! does not support $(eltype(C))=$(eltype(A))*$(eltype(B))"))
    computeT = juliaStorageType(eltype(C), computeType)
    if version() >= v"11.0"
        # with CUDA 11, the compute type encodes the math mode.
        cublasGemmStridedBatchedEx(handle(), transA, transB, m, n, k, Ref{computeT}(alpha),
                                   A, eltype(A), lda, strideA,
                                   B, eltype(B), ldb, strideB,
                                   Ref{computeT}(beta), C, eltype(C), ldc, strideC,
                                   batchCount, computeType, algo)
    else
        error("gemmStridedBatchedEx! is not implemented for CUDA versions below 11.")
    end
    C
end

# create a batch of pointers in device memory from a batch of device arrays
@inline function unsafe_batch(batch::Vector{<:CuArray{T}}) where {T}
ptrs = pointer.(batch)
Expand Down Expand Up @@ -969,6 +1054,7 @@ for (fname, elty) in
end
end
end

function gemm_batched(transA::Char, transB::Char, alpha::Number,
A::Vector{<:StridedCuMatrix{T}}, B::Vector{<:StridedCuMatrix{T}}) where T
C = CuMatrix{T}[similar(B[1], (size(A[1], transA == 'N' ? 1 : 2),size(B[1], transB == 'N' ? 2 : 1))) for i in 1:length(A)]
Expand Down
26 changes: 24 additions & 2 deletions test/libraries/cublas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1573,13 +1573,25 @@ end
@testset "gemm_batched" begin
    # C[i] = A[i] * B[i] for every batch entry, compared against host BLAS
    bd_C = CUBLAS.gemm_batched('N','N',bd_A,bd_B)
    for i in 1:length(bA)
        bC[i] = bA[i]*bB[i]
        h_C = Array(bd_C[i])
        @test bC[i] ≈ h_C
    end
    # mismatched inner dimensions must be rejected
    @test_throws DimensionMismatch CUBLAS.gemm_batched('N','N',alpha,bd_A,bd_bad)
end

@testset "gemmBatchedEx!" begin
    # C = (alpha*A)*B + beta*C, per batch entry, compared against host BLAS
    CUBLAS.gemmBatchedEx!('N','N',alpha,bd_A,bd_B,beta,bd_C)
    for i in 1:length(bd_C)
        bC[i] = (alpha*bA[i])*bB[i] + beta*bC[i]
        h_C = Array(bd_C[i])
        # compare device result against the host reference
        @test bC[i] ≈ h_C
    end
    # mismatched inner dimensions must be rejected
    @test_throws DimensionMismatch CUBLAS.gemmBatchedEx!('N','N',alpha,bd_A,bd_bad,beta,bd_C)
end

nbatch = 10
bA = rand(elty, m, k, nbatch)
bB = rand(elty, k, n, nbatch)
Expand All @@ -1601,6 +1613,16 @@ end
@test_throws DimensionMismatch CUBLAS.gemm_strided_batched!('N', 'N', alpha, bd_A, bd_B, beta, bd_bad)
end

@testset "gemmStridedBatchedEx!" begin
    # C[:,:,i] = (alpha*A[:,:,i])*B[:,:,i] + beta*C[:,:,i], compared slice-wise
    CUBLAS.gemmStridedBatchedEx!('N', 'N', alpha, bd_A, bd_B, beta, bd_C)
    for i in 1:nbatch
        bC[:, :, i] = (alpha * bA[:, :, i]) * bB[:, :, i] + beta * bC[:, :, i]
    end
    h_C = Array(bd_C)
    @test bC ≈ h_C
    # mismatched output batch must be rejected
    @test_throws DimensionMismatch CUBLAS.gemmStridedBatchedEx!('N', 'N', alpha, bd_A, bd_B, beta, bd_bad)
end

@testset "gemm_strided_batched" begin
bd_C = CUBLAS.gemm_strided_batched('N', 'N', bd_A, bd_B)

Expand Down

0 comments on commit abd569e

Please sign in to comment.