Skip to content

Commit

Permalink
Merge pull request #1089 from emankov/HIPIFY
Browse files (browse the repository at this point in the history)
[HIPIFY][6.0.0][rocBLAS] Support for ROCm HIP 6.0.0 - Step 21 - functions `rocblas_(hsh|hss|tst|tss)gemv_strided_batched`
  • Loading branch information
emankov authored Oct 23, 2023
2 parents 177cb8c + 0784d32 commit 6b11c61
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 16 deletions.
8 changes: 4 additions & 4 deletions bin/hipify-perl
Original file line number Diff line number Diff line change
Expand Up @@ -1467,7 +1467,9 @@ sub rocSubstitutions {
subst("cublasGetVector", "rocblas_get_vector", "library");
subst("cublasGetVectorAsync", "rocblas_get_vector_async", "library");
subst("cublasHSHgemvBatched", "rocblas_hshgemv_batched", "library");
subst("cublasHSHgemvStridedBatched", "rocblas_hshgemv_strided_batched", "library");
subst("cublasHSSgemvBatched", "rocblas_hssgemv_batched", "library");
subst("cublasHSSgemvStridedBatched", "rocblas_hssgemv_strided_batched", "library");
subst("cublasHgemm", "rocblas_hgemm", "library");
subst("cublasHgemmBatched", "rocblas_hgemm_batched", "library");
subst("cublasHgemmStridedBatched", "rocblas_hgemm_strided_batched", "library");
Expand Down Expand Up @@ -1578,7 +1580,9 @@ sub rocSubstitutions {
subst("cublasStrsv", "rocblas_strsv", "library");
subst("cublasStrsv_v2", "rocblas_strsv", "library");
subst("cublasTSSgemvBatched", "rocblas_tssgemv_batched", "library");
subst("cublasTSSgemvStridedBatched", "rocblas_tssgemv_strided_batched", "library");
subst("cublasTSTgemvBatched", "rocblas_tstgemv_batched", "library");
subst("cublasTSTgemvStridedBatched", "rocblas_tstgemv_strided_batched", "library");
subst("cublasZaxpy", "rocblas_zaxpy", "library");
subst("cublasZaxpy_v2", "rocblas_zaxpy", "library");
subst("cublasZcopy", "rocblas_zcopy", "library");
Expand Down Expand Up @@ -9961,10 +9965,8 @@ sub warnRocOnlyUnsupportedFunctions {
"cublasXerbla",
"cublasUint8gemmBias",
"cublasTSTgemvStridedBatched_64",
"cublasTSTgemvStridedBatched",
"cublasTSTgemvBatched_64",
"cublasTSSgemvStridedBatched_64",
"cublasTSSgemvStridedBatched",
"cublasTSSgemvBatched_64",
"cublasSwapEx_64",
"cublasSwapEx",
Expand Down Expand Up @@ -10096,10 +10098,8 @@ sub warnRocOnlyUnsupportedFunctions {
"cublasHgemmStridedBatched_64",
"cublasHgemmBatched_64",
"cublasHSSgemvStridedBatched_64",
"cublasHSSgemvStridedBatched",
"cublasHSSgemvBatched_64",
"cublasHSHgemvStridedBatched_64",
"cublasHSHgemvStridedBatched",
"cublasHSHgemvBatched_64",
"cublasGetVersion_v2",
"cublasGetVersion",
Expand Down
8 changes: 4 additions & 4 deletions docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md
Original file line number Diff line number Diff line change
Expand Up @@ -799,11 +799,11 @@
|`cublasDtrsm_v2_64`|12.0| | | | | | | | | | | | | | |
|`cublasHSHgemvBatched`|11.6| | | | | | | | |`rocblas_hshgemv_batched`|6.0.0| | | |6.0.0|
|`cublasHSHgemvBatched_64`|12.0| | | | | | | | | | | | | | |
|`cublasHSHgemvStridedBatched`|11.6| | | | | | | | | | | | | | |
|`cublasHSHgemvStridedBatched`|11.6| | | | | | | | |`rocblas_hshgemv_strided_batched`|6.0.0| | | |6.0.0|
|`cublasHSHgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | |
|`cublasHSSgemvBatched`|11.6| | | | | | | | |`rocblas_hssgemv_batched`|6.0.0| | | |6.0.0|
|`cublasHSSgemvBatched_64`|12.0| | | | | | | | | | | | | | |
|`cublasHSSgemvStridedBatched`|11.6| | | | | | | | | | | | | | |
|`cublasHSSgemvStridedBatched`|11.6| | | | | | | | |`rocblas_hssgemv_strided_batched`|6.0.0| | | |6.0.0|
|`cublasHSSgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | |
|`cublasHgemm`|7.5| | |`hipblasHgemm`|1.8.2| | | | |`rocblas_hgemm`|1.5.0| | | | |
|`cublasHgemmBatched`|9.0| | |`hipblasHgemmBatched`|3.0.0| | | | |`rocblas_hgemm_batched`|3.5.0| | | | |
Expand Down Expand Up @@ -847,11 +847,11 @@
|`cublasStrsm_v2_64`|12.0| | | | | | | | | | | | | | |
|`cublasTSSgemvBatched`|11.6| | | | | | | | |`rocblas_tssgemv_batched`|6.0.0| | | |6.0.0|
|`cublasTSSgemvBatched_64`|12.0| | | | | | | | | | | | | | |
|`cublasTSSgemvStridedBatched`|11.6| | | | | | | | | | | | | | |
|`cublasTSSgemvStridedBatched`|11.6| | | | | | | | |`rocblas_tssgemv_strided_batched`|6.0.0| | | |6.0.0|
|`cublasTSSgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | |
|`cublasTSTgemvBatched`|11.6| | | | | | | | |`rocblas_tstgemv_batched`|6.0.0| | | |6.0.0|
|`cublasTSTgemvBatched_64`|12.0| | | | | | | | | | | | | | |
|`cublasTSTgemvStridedBatched`|11.6| | | | | | | | | | | | | | |
|`cublasTSTgemvStridedBatched`|11.6| | | | | | | | |`rocblas_tstgemv_strided_batched`|6.0.0| | | |6.0.0|
|`cublasTSTgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | |
|`cublasZgemm`| | | |`hipblasZgemm_v2`|6.0.0| | | |6.0.0|`rocblas_zgemm`|1.5.0| | | | |
|`cublasZgemm3m`|8.0| | | | | | | | | | | | | | |
Expand Down
8 changes: 4 additions & 4 deletions docs/tables/CUBLAS_API_supported_by_ROC.md
Original file line number Diff line number Diff line change
Expand Up @@ -799,11 +799,11 @@
|`cublasDtrsm_v2_64`|12.0| | | | | | | | |
|`cublasHSHgemvBatched`|11.6| | |`rocblas_hshgemv_batched`|6.0.0| | | |6.0.0|
|`cublasHSHgemvBatched_64`|12.0| | | | | | | | |
|`cublasHSHgemvStridedBatched`|11.6| | | | | | | | |
|`cublasHSHgemvStridedBatched`|11.6| | |`rocblas_hshgemv_strided_batched`|6.0.0| | | |6.0.0|
|`cublasHSHgemvStridedBatched_64`|12.0| | | | | | | | |
|`cublasHSSgemvBatched`|11.6| | |`rocblas_hssgemv_batched`|6.0.0| | | |6.0.0|
|`cublasHSSgemvBatched_64`|12.0| | | | | | | | |
|`cublasHSSgemvStridedBatched`|11.6| | | | | | | | |
|`cublasHSSgemvStridedBatched`|11.6| | |`rocblas_hssgemv_strided_batched`|6.0.0| | | |6.0.0|
|`cublasHSSgemvStridedBatched_64`|12.0| | | | | | | | |
|`cublasHgemm`|7.5| | |`rocblas_hgemm`|1.5.0| | | | |
|`cublasHgemmBatched`|9.0| | |`rocblas_hgemm_batched`|3.5.0| | | | |
Expand Down Expand Up @@ -847,11 +847,11 @@
|`cublasStrsm_v2_64`|12.0| | | | | | | | |
|`cublasTSSgemvBatched`|11.6| | |`rocblas_tssgemv_batched`|6.0.0| | | |6.0.0|
|`cublasTSSgemvBatched_64`|12.0| | | | | | | | |
|`cublasTSSgemvStridedBatched`|11.6| | | | | | | | |
|`cublasTSSgemvStridedBatched`|11.6| | |`rocblas_tssgemv_strided_batched`|6.0.0| | | |6.0.0|
|`cublasTSSgemvStridedBatched_64`|12.0| | | | | | | | |
|`cublasTSTgemvBatched`|11.6| | |`rocblas_tstgemv_batched`|6.0.0| | | |6.0.0|
|`cublasTSTgemvBatched_64`|12.0| | | | | | | | |
|`cublasTSTgemvStridedBatched`|11.6| | | | | | | | |
|`cublasTSTgemvStridedBatched`|11.6| | |`rocblas_tstgemv_strided_batched`|6.0.0| | | |6.0.0|
|`cublasTSTgemvStridedBatched_64`|12.0| | | | | | | | |
|`cublasZgemm`| | | |`rocblas_zgemm`|1.5.0| | | | |
|`cublasZgemm3m`|8.0| | | | | | | | |
Expand Down
12 changes: 8 additions & 4 deletions src/CUDA2HIP_BLAS_API_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,13 +458,13 @@ const std::map<llvm::StringRef, hipCounter> CUDA_BLAS_FUNCTION_MAP {
{"cublasCgemvStridedBatched_64", {"hipblasCgemvStridedBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}},
{"cublasZgemvStridedBatched", {"hipblasZgemvStridedBatched_v2", "rocblas_zgemv_strided_batched", CONV_LIB_FUNC, API_BLAS, 7}},
{"cublasZgemvStridedBatched_64", {"hipblasZgemvStridedBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}},
{"cublasHSHgemvStridedBatched", {"hipblasHSHgemvStridedBatched", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}},
{"cublasHSHgemvStridedBatched", {"hipblasHSHgemvStridedBatched", "rocblas_hshgemv_strided_batched", CONV_LIB_FUNC, API_BLAS, 7, HIP_UNSUPPORTED}},
{"cublasHSHgemvStridedBatched_64", {"hipblasHSHgemvStridedBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}},
{"cublasHSSgemvStridedBatched", {"hipblasHSSgemvStridedBatched", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}},
{"cublasHSSgemvStridedBatched", {"hipblasHSSgemvStridedBatched", "rocblas_hssgemv_strided_batched", CONV_LIB_FUNC, API_BLAS, 7, HIP_UNSUPPORTED}},
{"cublasHSSgemvStridedBatched_64", {"hipblasHSSgemvStridedBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}},
{"cublasTSTgemvStridedBatched", {"hipblasTSTgemvStridedBatched", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}},
{"cublasTSTgemvStridedBatched", {"hipblasTSTgemvStridedBatched", "rocblas_tstgemv_strided_batched", CONV_LIB_FUNC, API_BLAS, 7, HIP_UNSUPPORTED}},
{"cublasTSTgemvStridedBatched_64", {"hipblasTSTgemvStridedBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}},
{"cublasTSSgemvStridedBatched", {"hipblasTSSgemvStridedBatched", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}},
{"cublasTSSgemvStridedBatched", {"hipblasTSSgemvStridedBatched", "rocblas_tssgemv_strided_batched", CONV_LIB_FUNC, API_BLAS, 7, HIP_UNSUPPORTED}},
{"cublasTSSgemvStridedBatched_64", {"hipblasTSSgemvStridedBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}},

// SYRK
Expand Down Expand Up @@ -2100,6 +2100,10 @@ const std::map<llvm::StringRef, hipAPIversions> HIP_BLAS_FUNCTION_VER_MAP {
{"rocblas_hssgemv_batched", {HIP_6000, HIP_0, HIP_0, HIP_LATEST}},
{"rocblas_tstgemv_batched", {HIP_6000, HIP_0, HIP_0, HIP_LATEST}},
{"rocblas_tssgemv_batched", {HIP_6000, HIP_0, HIP_0, HIP_LATEST}},
{"rocblas_hshgemv_strided_batched", {HIP_6000, HIP_0, HIP_0, HIP_LATEST}},
{"rocblas_hssgemv_strided_batched", {HIP_6000, HIP_0, HIP_0, HIP_LATEST}},
{"rocblas_tstgemv_strided_batched", {HIP_6000, HIP_0, HIP_0, HIP_LATEST}},
{"rocblas_tssgemv_strided_batched", {HIP_6000, HIP_0, HIP_0, HIP_LATEST}},
};

const std::map<llvm::StringRef, hipAPIChangedVersions> HIP_BLAS_FUNCTION_CHANGED_VER_MAP {
Expand Down
31 changes: 31 additions & 0 deletions tests/unit_tests/synthetic/libraries/cublas2rocblas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,10 @@ int main() {
__half* hc = 0;
// CHECK: rocblas_half* hC = 0;
__half* hC = 0;
// CHECK: rocblas_half* hx = 0;
__half* hx = 0;
// CHECK: rocblas_half* hy = 0;
__half* hy = 0;

// CHECK: rocblas_half** hAarray = 0;
__half** hAarray = 0;
Expand All @@ -274,6 +278,13 @@ int main() {
// CHECK: rocblas_half** hyarray = 0;
__half** hyarray = 0;

// CHECK: rocblas_bfloat16* bf16A = 0;
__nv_bfloat16* bf16A = 0;
// CHECK: rocblas_bfloat16* bf16x = 0;
__nv_bfloat16* bf16x = 0;
// CHECK: rocblas_bfloat16* bf16y = 0;
__nv_bfloat16* bf16y = 0;

// CHECK: rocblas_bfloat16** bf16Aarray = 0;
__nv_bfloat16** bf16Aarray = 0;
// CHECK: const rocblas_bfloat16** const bf16Aarray_const = const_cast<const rocblas_bfloat16**>(bf16Aarray);
Expand Down Expand Up @@ -1838,6 +1849,26 @@ int main() {
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_tssgemv_batched(rocblas_handle handle, rocblas_operation trans, rocblas_int m, rocblas_int n, const float* alpha, const rocblas_bfloat16* const A[], rocblas_int lda, const rocblas_bfloat16* const x[], rocblas_int incx, const float* beta, float* const y[], rocblas_int incy, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_tssgemv_batched(blasHandle, blasOperation, m, n, &fa, bf16Aarray_const, lda, bf16xarray_const, incx, &fb, fyarray, incy, batchCount);
blasStatus = cublasTSSgemvBatched(blasHandle, blasOperation, m, n, &fa, bf16Aarray_const, lda, bf16xarray_const, incx, &fb, fyarray, incy, batchCount);

// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHSHgemvStridedBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const float* alpha, const __half* A, int lda, long long int strideA, const __half* x, int incx, long long int stridex, const float* beta, __half* y, int incy, long long int stridey, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_hshgemv_strided_batched(rocblas_handle handle, rocblas_operation transA, rocblas_int m, rocblas_int n, const float* alpha, const rocblas_half* A, rocblas_int lda, rocblas_stride strideA, const rocblas_half* x, rocblas_int incx, rocblas_stride stridex, const float* beta, rocblas_half* y, rocblas_int incy, rocblas_stride stridey, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_hshgemv_strided_batched(blasHandle, blasOperation, m, n, &fa, hA, lda, strideA, hx, incx, stridex, &fb, hy, incy, stridey, batchCount);
blasStatus = cublasHSHgemvStridedBatched(blasHandle, blasOperation, m, n, &fa, hA, lda, strideA, hx, incx, stridex, &fb, hy, incy, stridey, batchCount);

// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHSSgemvStridedBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const float* alpha, const __half* A, int lda, long long int strideA, const __half* x, int incx, long long int stridex, const float* beta, float* y, int incy, long long int stridey, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_hssgemv_strided_batched(rocblas_handle handle, rocblas_operation transA, rocblas_int m, rocblas_int n, const float* alpha, const rocblas_half* A, rocblas_int lda, rocblas_stride strideA, const rocblas_half* x, rocblas_int incx, rocblas_stride stridex, const float* beta, float* y, rocblas_int incy, rocblas_stride stridey, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_hssgemv_strided_batched(blasHandle, blasOperation, m, n, &fa, hA, lda, strideA, hx, incx, stridex, &fb, &fy, incy, stridey, batchCount);
blasStatus = cublasHSSgemvStridedBatched(blasHandle, blasOperation, m, n, &fa, hA, lda, strideA, hx, incx, stridex, &fb, &fy, incy, stridey, batchCount);

// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasTSTgemvStridedBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const float* alpha, const __nv_bfloat16* A, int lda, long long int strideA, const __nv_bfloat16* x, int incx, long long int stridex, const float* beta, __nv_bfloat16* y, int incy, long long int stridey, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_tstgemv_strided_batched(rocblas_handle handle, rocblas_operation transA, rocblas_int m, rocblas_int n, const float* alpha, const rocblas_bfloat16* A, rocblas_int lda, rocblas_stride strideA, const rocblas_bfloat16* x, rocblas_int incx, rocblas_stride stridex, const float* beta, rocblas_bfloat16* y, rocblas_int incy, rocblas_stride stridey, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_tstgemv_strided_batched(blasHandle, blasOperation, m, n, &fa, bf16A, lda, strideA, bf16x, incx, stridex, &fb, bf16y, incy, stridey, batchCount);
blasStatus = cublasTSTgemvStridedBatched(blasHandle, blasOperation, m, n, &fa, bf16A, lda, strideA, bf16x, incx, stridex, &fb, bf16y, incy, stridey, batchCount);

// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasTSSgemvStridedBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const float* alpha, const __nv_bfloat16* A, int lda, long long int strideA, const __nv_bfloat16* x, int incx, long long int stridex, const float* beta, float* y, int incy, long long int stridey, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_tssgemv_strided_batched(rocblas_handle handle, rocblas_operation transA, rocblas_int m, rocblas_int n, const float* alpha, const rocblas_bfloat16* A, rocblas_int lda, rocblas_stride strideA, const rocblas_bfloat16* x, rocblas_int incx, rocblas_stride stridex, const float* beta, float* y, rocblas_int incy, rocblas_stride stridey, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_tssgemv_strided_batched(blasHandle, blasOperation, m, n, &fa, bf16A, lda, strideA, bf16x, incx, stridex, &fb, &fy, incy, stridey, batchCount);
blasStatus = cublasTSSgemvStridedBatched(blasHandle, blasOperation, m, n, &fa, bf16A, lda, strideA, bf16x, incx, stridex, &fb, &fy, incy, stridey, batchCount);
#endif

return 0;
Expand Down

0 comments on commit 6b11c61

Please sign in to comment.