Skip to content

Commit 4d85f27

Browse files
authored
[CUSPARSE] Support CuSparseMatrixBSR in the generic mm! (#2639)
1 parent 031d7b9 commit 4d85f27

File tree

8 files changed

+97
-111
lines changed

8 files changed

+97
-111
lines changed

lib/cusparse/generic.jl

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,27 @@
11
# generic APIs
22

33
export gather!, scatter!, axpby!, rot!
4-
export vv!, sv!, sm!, gemv, gemm, gemm!, sddmm!
4+
export vv!, sv!, sm!, mv!, mm!, gemv, gemm, gemm!, sddmm!
55
export bmm!
66

7+
"""
8+
mv!(transa::SparseChar, alpha::Number, A::CuSparseMatrix, X::CuVector, beta::Number, Y::CuVector, index::SparseChar)
9+
10+
Performs `Y = alpha * op(A) * X + beta * Y`, where `op` can be nothing (`transa = N`),
11+
transpose (`transa = T`) or conjugate transpose (`transa = C`).
12+
`X` and `Y` are dense vectors.
13+
"""
14+
function mv! end
15+
16+
"""
17+
mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrix, B::CuMatrix, beta::Number, C::CuMatrix, index::SparseChar)
18+
mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuMatrix, B::Union{CuSparseMatrixCSC,CuSparseMatrixCSR,CuSparseMatrixCOO}, beta::Number, C::CuMatrix, index::SparseChar)
19+
20+
Performs `C = alpha * op(A) * op(B) + beta * C`, where `op` can be nothing (`transa = N`),
21+
transpose (`transa = T`) or conjugate transpose (`transa = C`).
22+
"""
23+
function mm! end
24+
725
## API functions
826

927
function sparsetodense(A::Union{CuSparseMatrixCSC{T},CuSparseMatrixCSR{T},CuSparseMatrixCOO{T}}, index::SparseChar, algo::cusparseSparseToDenseAlg_t=CUSPARSE_SPARSETODENSE_ALG_DEFAULT) where {T}
@@ -191,9 +209,11 @@ function mv!(transa::SparseChar, alpha::Number, A::Union{CuSparseMatrixCSC{TA},C
191209
return Y
192210
end
193211

194-
function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::Union{CuSparseMatrixCSC{T},CuSparseMatrixCSR{T},CuSparseMatrixCOO{T}},
212+
function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrix{T},
195213
B::DenseCuMatrix{T}, beta::Number, C::DenseCuMatrix{T}, index::SparseChar, algo::cusparseSpMMAlg_t=CUSPARSE_SPMM_ALG_DEFAULT) where {T}
196214

215+
(A isa CuSparseMatrixBSR) && (CUSPARSE.version() < v"12.5.1") && throw(ErrorException("This operation is not supported by the current CUDA version."))
216+
197217
# Support `transa = 'C'` and `transb = 'C'` for real matrices
198218
transa = T <: Real && transa == 'C' ? 'T' : transa
199219
transb = T <: Real && transb == 'C' ? 'T' : transb
@@ -235,10 +255,10 @@ function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::Union{CuS
235255
# cusparseCsrSetStridedBatch(obj, batchsize, 0, nnz(A))
236256
# end
237257

238-
# Set default buffer for small matrices (10000 chosen arbitrarly)
258+
# Set default buffer for small matrices (1000 chosen arbitrarily)
239259
# Otherwise tries to allocate 120TB of memory (see #2296)
240260
function bufferSize()
241-
out = Ref{Csize_t}(10000)
261+
out = Ref{Csize_t}(1000)
242262
cusparseSpMM_bufferSize(
243263
handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta),
244264
descC, T, algo, out)
@@ -274,7 +294,6 @@ function bmm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparse
274294
throw(ErrorException("Batched dense-matrix times batched sparse-matrix (bmm!) requires a CUSPARSE version ≥ 11.7.2 (yours: $(CUSPARSE.version()))."))
275295
end
276296

277-
278297
# Support `transa = 'C'` and `transb = 'C'` for real matrices
279298
transa = T <: Real && transa == 'C' ? 'T' : transa
280299
transb = T <: Real && transb == 'C' ? 'T' : transb
@@ -313,10 +332,10 @@ function bmm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparse
313332
strideC = stride(C, 3)
314333
cusparseDnMatSetStridedBatch(descC, b, strideC)
315334

316-
# Set default buffer for small matrices (10000 chosen arbitrarly)
335+
# Set default buffer for small matrices (1000 chosen arbitrarily)
317336
# Otherwise tries to allocate 120TB of memory (see #2296)
318337
function bufferSize()
319-
out = Ref{Csize_t}(10000)
338+
out = Ref{Csize_t}(1000)
320339
cusparseSpMM_bufferSize(
321340
handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta),
322341
descC, T, algo, out)
@@ -337,10 +356,11 @@ function bmm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparse
337356
end
338357

339358
function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::DenseCuMatrix{T},
340-
B::Union{CuSparseMatrixCSC{T},CuSparseMatrixCSR{T},CuSparseMatrixCOO{T}},
341-
beta::Number, C::DenseCuMatrix{T}, index::SparseChar, algo::cusparseSpMMAlg_t=CUSPARSE_SPMM_ALG_DEFAULT) where {T}
359+
B::Union{CuSparseMatrixCSC{T},CuSparseMatrixCSR{T},CuSparseMatrixCOO{T}}, beta::Number,
360+
C::DenseCuMatrix{T}, index::SparseChar, algo::cusparseSpMMAlg_t=CUSPARSE_SPMM_ALG_DEFAULT) where {T}
342361

343362
CUSPARSE.version() < v"11.7.4" && throw(ErrorException("This operation is not supported by the current CUDA version."))
363+
344364
# Support `transa = 'C'` and `transb = 'C'` for real matrices
345365
transa = T <: Real && transa == 'C' ? 'T' : transa
346366
transb = T <: Real && transb == 'C' ? 'T' : transb
@@ -373,10 +393,10 @@ function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::DenseCuMa
373393
descB = CuSparseMatrixDescriptor(B, index, transposed=true)
374394
descC = CuDenseMatrixDescriptor(C, transposed=true)
375395

376-
# Set default buffer for small matrices (10000 chosen arbitrarly)
396+
# Set default buffer for small matrices (1000 chosen arbitrarily)
377397
# Otherwise tries to allocate 120TB of memory (see #2296)
378398
function bufferSize()
379-
out = Ref{Csize_t}(10000)
399+
out = Ref{Csize_t}(1000)
380400
cusparseSpMM_bufferSize(
381401
handle(), transb, transa, Ref{T}(alpha), descB, descA, Ref{T}(beta),
382402
descC, T, algo, out)
@@ -736,9 +756,10 @@ function sm!(transa::SparseChar, transb::SparseChar, uplo::SparseChar, diag::Spa
736756
end
737757

738758
function sddmm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::DenseCuMatrix{T}, B::DenseCuMatrix{T},
739-
beta::Number, C::CuSparseMatrixCSR{T}, index::SparseChar, algo::cusparseSDDMMAlg_t=CUSPARSE_SDDMM_ALG_DEFAULT) where {T}
759+
beta::Number, C::Union{CuSparseMatrixCSR{T},CuSparseMatrixBSR{T}}, index::SparseChar, algo::cusparseSDDMMAlg_t=CUSPARSE_SDDMM_ALG_DEFAULT) where {T}
740760

741761
CUSPARSE.version() < v"11.4.1" && throw(ErrorException("This operation is not supported by the current CUDA version."))
762+
(C isa CuSparseMatrixBSR) && (CUSPARSE.version() < v"12.1.0") && throw(ErrorException("This operation is not supported by the current CUDA version."))
742763

743764
# Support `transa = 'C'` and `transb = 'C'` for real matrices
744765
transa = T <: Real && transa == 'C' ? 'T' : transa

lib/cusparse/level2.jl

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,7 @@
11
# sparse linear algebra functions that perform operations between sparse matrices and dense
22
# vectors
33

4-
export mv!, sv2!, sv2, gemvi!
5-
6-
"""
7-
mv!(transa::SparseChar, alpha::Number, A::CuSparseMatrix, X::CuVector, beta::Number, Y::CuVector, index::SparseChar)
8-
9-
Performs `Y = alpha * op(A) * X + beta * Y`, where `op` can be nothing (`transa = N`),
10-
tranpose (`transa = T`) or conjugate transpose (`transa = C`).
11-
`X` and `Y` are dense vectors.
12-
"""
13-
mv!(transa::SparseChar, alpha::Number, A::CuSparseMatrix, X::CuVector, beta::Number, Y::CuVector, index::SparseChar)
4+
export sv2!, sv2, gemvi!
145

156
for (fname,elty) in ((:cusparseSbsrmv, :Float32),
167
(:cusparseDbsrmv, :Float64),

lib/cusparse/level3.jl

Lines changed: 1 addition & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,7 @@
11
# sparse linear algebra functions that perform operations between sparse and (usually tall)
22
# dense matrices
33

4-
export mm!, sm2!, sm2
5-
6-
"""
7-
mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrix, B::CuMatrix, beta::Number, C::CuMatrix, index::SparseChar)
8-
9-
Performs `C = alpha * op(A) * op(B) + beta * C`, where `op` can be nothing (`transa = N`),
10-
tranpose (`transa = T`) or conjugate transpose (`transa = C`).
11-
`B` and `C` are dense matrices.
12-
"""
13-
mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrix, B::CuMatrix, beta::Number, C::CuMatrix, index::SparseChar)
14-
15-
# bsrmm
16-
for (fname,elty) in ((:cusparseSbsrmm, :Float32),
17-
(:cusparseDbsrmm, :Float64),
18-
(:cusparseCbsrmm, :ComplexF32),
19-
(:cusparseZbsrmm, :ComplexF64))
20-
@eval begin
21-
function mm!(transa::SparseChar,
22-
transb::SparseChar,
23-
alpha::Number,
24-
A::CuSparseMatrixBSR{$elty},
25-
B::StridedCuMatrix{$elty},
26-
beta::Number,
27-
C::StridedCuMatrix{$elty},
28-
index::SparseChar)
29-
30-
# Support transa = 'C' and `transb = 'C' for real matrices
31-
transa = $elty <: Real && transa == 'C' ? 'T' : transa
32-
transb = $elty <: Real && transb == 'C' ? 'T' : transb
33-
34-
desc = CuMatrixDescriptor('G', 'L', 'N', index)
35-
m,k = size(A)
36-
mb = cld(m, A.blockDim)
37-
kb = cld(k, A.blockDim)
38-
n = size(C)[2]
39-
if transa == 'N' && transb == 'N'
40-
chkmmdims(B,C,k,n,m,n)
41-
elseif transa == 'N' && transb != 'N'
42-
chkmmdims(B,C,n,k,m,n)
43-
elseif transa != 'N' && transb == 'N'
44-
chkmmdims(B,C,m,n,k,n)
45-
elseif transa != 'N' && transb != 'N'
46-
chkmmdims(B,C,n,m,k,n)
47-
end
48-
ldb = max(1,stride(B,2))
49-
ldc = max(1,stride(C,2))
50-
$fname(handle(), A.dir,
51-
transa, transb, mb, n, kb, A.nnzb,
52-
alpha, desc, nonzeros(A),A.rowPtr, A.colVal,
53-
A.blockDim, B, ldb, beta, C, ldc)
54-
C
55-
end
56-
end
57-
end
4+
export sm2!, sm2
585

596
"""
607
sm2!(transa::SparseChar, transxy::SparseChar, uplo::SparseChar, diag::SparseChar, alpha::BlasFloat, A::CuSparseMatrix, X::CuMatrix, index::SparseChar)

lib/cusparse/libcusparse.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5415,9 +5415,9 @@ end
54155415
@gcsafe_ccall libcusparse.cusparseCreateBsr(spMatDescr::Ptr{cusparseSpMatDescr_t},
54165416
brows::Int64, bcols::Int64, bnnz::Int64,
54175417
rowBlockSize::Int64, colBlockSize::Int64,
5418-
bsrRowOffsets::Ptr{Cvoid},
5419-
bsrColInd::Ptr{Cvoid},
5420-
bsrValues::Ptr{Cvoid},
5418+
bsrRowOffsets::CuPtr{Cvoid},
5419+
bsrColInd::CuPtr{Cvoid},
5420+
bsrValues::CuPtr{Cvoid},
54215421
bsrRowOffsetsType::cusparseIndexType_t,
54225422
bsrColIndType::cusparseIndexType_t,
54235423
idxBase::cusparseIndexBase_t,
@@ -5434,9 +5434,9 @@ end
54345434
brows::Int64, bcols::Int64,
54355435
bnnz::Int64, rowBlockDim::Int64,
54365436
colBlockDim::Int64,
5437-
bsrRowOffsets::Ptr{Cvoid},
5438-
bsrColInd::Ptr{Cvoid},
5439-
bsrValues::Ptr{Cvoid},
5437+
bsrRowOffsets::CuPtr{Cvoid},
5438+
bsrColInd::CuPtr{Cvoid},
5439+
bsrValues::CuPtr{Cvoid},
54405440
bsrRowOffsetsType::cusparseIndexType_t,
54415441
bsrColIndType::cusparseIndexType_t,
54425442
idxBase::cusparseIndexBase_t,

res/wrap/cusparse.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -990,3 +990,13 @@ needs_context = false
990990

991991
[api.cusparseSpMMOp.argtypes]
992992
2 = "CuPtr{Cvoid}"
993+
994+
[api.cusparseCreateBsr.argtypes]
995+
7 = "CuPtr{Cvoid}"
996+
8 = "CuPtr{Cvoid}"
997+
9 = "CuPtr{Cvoid}"
998+
999+
[api.cusparseCreateConstBsr.argtypes]
1000+
7 = "CuPtr{Cvoid}"
1001+
8 = "CuPtr{Cvoid}"
1002+
9 = "CuPtr{Cvoid}"

test/libraries/cusparse.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -908,8 +908,7 @@ end
908908
alpha = rand(elty)
909909
beta = rand(elty)
910910
@testset "$(typeof(d_A))" for d_A in [CuSparseMatrixCSR(A),
911-
CuSparseMatrixCSC(A),
912-
CuSparseMatrixBSR(A, blockdim)]
911+
CuSparseMatrixCSC(A)]
913912
d_B = CuArray(B)
914913
d_C = CuArray(C)
915914
@test_throws DimensionMismatch CUSPARSE.mm!('N','T',alpha,d_A,d_B,beta,d_C,'O')

0 commit comments

Comments
 (0)