From 967de9b1ad31dde7dc5f3bf83f233de617f2dc87 Mon Sep 17 00:00:00 2001 From: Evgeny Mankov Date: Fri, 21 Jun 2024 15:29:04 +0100 Subject: [PATCH] [HIPIFY][BLAS][fix] Added support for the missing `hipblas(S|D)gemvStridedBatched` + Updated synthetic tests, the regenerated `hipify-perl`, and `BLAS` `CUDA2HIP` documentation --- bin/hipify-perl | 4 ++-- docs/tables/CUBLAS_API_supported_by_HIP.md | 4 ++-- docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md | 4 ++-- src/CUDA2HIP_BLAS_API_functions.cpp | 8 +++++--- .../synthetic/libraries/cublas2hipblas.cu | 2 +- .../synthetic/libraries/cublas2hipblas_v2.cu | 14 ++++++++++++++ 6 files changed, 26 insertions(+), 10 deletions(-) diff --git a/bin/hipify-perl b/bin/hipify-perl index 257fb081..accb41a5 100755 --- a/bin/hipify-perl +++ b/bin/hipify-perl @@ -3921,6 +3921,7 @@ sub simpleSubstitutions { subst("cublasDgemmStridedBatched", "hipblasDgemmStridedBatched", "library"); subst("cublasDgemm_v2", "hipblasDgemm", "library"); subst("cublasDgemv", "hipblasDgemv", "library"); + subst("cublasDgemvStridedBatched", "hipblasDgemvStridedBatched", "library"); subst("cublasDgemv_v2", "hipblasDgemv", "library"); subst("cublasDgeqrfBatched", "hipblasDgeqrfBatched", "library"); subst("cublasDger", "hipblasDger", "library"); @@ -4116,6 +4117,7 @@ sub simpleSubstitutions { subst("cublasSgemmStridedBatched", "hipblasSgemmStridedBatched", "library"); subst("cublasSgemm_v2", "hipblasSgemm", "library"); subst("cublasSgemv", "hipblasSgemv", "library"); + subst("cublasSgemvStridedBatched", "hipblasSgemvStridedBatched", "library"); subst("cublasSgemv_v2", "hipblasSgemv", "library"); subst("cublasSgeqrfBatched", "hipblasSgeqrfBatched", "library"); subst("cublasSger", "hipblasSger", "library"); @@ -11439,7 +11441,6 @@ sub warnHipOnlyUnsupportedFunctions { "cublasSger_v2_64", "cublasSger_64", "cublasSgemvStridedBatched_64", - "cublasSgemvStridedBatched", "cublasSgemvBatched", "cublasSgemm_v2_64", "cublasSgemm_64", @@ -11583,7 +11584,6 @@ sub warnHipOnlyUnsupportedFunctions { "cublasDger_v2_64", "cublasDger_64", "cublasDgemvStridedBatched_64", - "cublasDgemvStridedBatched", "cublasDgemvBatched", "cublasDgemm_v2_64", "cublasDgemm_64", diff --git a/docs/tables/CUBLAS_API_supported_by_HIP.md b/docs/tables/CUBLAS_API_supported_by_HIP.md index 57edcca7..1efe15f2 100644 --- a/docs/tables/CUBLAS_API_supported_by_HIP.md +++ b/docs/tables/CUBLAS_API_supported_by_HIP.md @@ -1083,7 +1083,7 @@ |`cublasDgemm_v2_64`|12.0| | | | | | | | | | |`cublasDgemvBatched`|11.6| | | | | | | | | | |`cublasDgemvBatched_64`|12.0| | | |`hipblasDgemvBatched_64`|6.2.0| | | |6.2.0| -|`cublasDgemvStridedBatched`|11.6| | | | | | | | | | +|`cublasDgemvStridedBatched`|11.6| | | |`hipblasDgemvStridedBatched`|3.0.0| | | | | |`cublasDgemvStridedBatched_64`|12.0| | | | | | | | | | |`cublasDsymm`| | | | |`hipblasDsymm`|3.6.0| | | | | |`cublasDsymm_64`|12.0| | | | | | | | | | @@ -1133,7 +1133,7 @@ |`cublasSgemm_v2_64`|12.0| | | | | | | | | | |`cublasSgemvBatched`|11.6| | | | | | | | | | |`cublasSgemvBatched_64`|12.0| | | |`hipblasSgemvBatched_64`|6.2.0| | | |6.2.0| -|`cublasSgemvStridedBatched`|11.6| | | | | | | | | | +|`cublasSgemvStridedBatched`|11.6| | | |`hipblasSgemvStridedBatched`|3.0.0| | | | | |`cublasSgemvStridedBatched_64`|12.0| | | | | | | | | | |`cublasSsymm`| | | | |`hipblasSsymm`|3.6.0| | | | | |`cublasSsymm_64`|12.0| | | | | | | | | | diff --git a/docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md b/docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md index 6896433c..a05b1c74 100644 --- a/docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md +++ b/docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md @@ -1083,7 +1083,7 @@ |`cublasDgemm_v2_64`|12.0| | | | | | | | | | | | | | | | |`cublasDgemvBatched`|11.6| | | | | | | | | | | | | | | | |`cublasDgemvBatched_64`|12.0| | | |`hipblasDgemvBatched_64`|6.2.0| | | |6.2.0| | | | | | | -|`cublasDgemvStridedBatched`|11.6| | | | | | | | | | | | | | | | +|`cublasDgemvStridedBatched`|11.6| | | |`hipblasDgemvStridedBatched`|3.0.0| | | | | | | | | | | |`cublasDgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | | | |`cublasDsymm`| | | | |`hipblasDsymm`|3.6.0| | | | |`rocblas_dsymm`|3.5.0| | | | | |`cublasDsymm_64`|12.0| | | | | | | | | | | | | | | | @@ -1133,7 +1133,7 @@ |`cublasSgemm_v2_64`|12.0| | | | | | | | | | | | | | | | |`cublasSgemvBatched`|11.6| | | | | | | | | | | | | | | | |`cublasSgemvBatched_64`|12.0| | | |`hipblasSgemvBatched_64`|6.2.0| | | |6.2.0| | | | | | | -|`cublasSgemvStridedBatched`|11.6| | | | | | | | | | | | | | | | +|`cublasSgemvStridedBatched`|11.6| | | |`hipblasSgemvStridedBatched`|3.0.0| | | | | | | | | | | |`cublasSgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | | | |`cublasSsymm`| | | | |`hipblasSsymm`|3.6.0| | | | |`rocblas_ssymm`|3.5.0| | | | | |`cublasSsymm_64`|12.0| | | | | | | | | | | | | | | | diff --git a/src/CUDA2HIP_BLAS_API_functions.cpp b/src/CUDA2HIP_BLAS_API_functions.cpp index 7db21c89..ddb5cb83 100644 --- a/src/CUDA2HIP_BLAS_API_functions.cpp +++ b/src/CUDA2HIP_BLAS_API_functions.cpp @@ -456,9 +456,9 @@ const std::map CUDA_BLAS_FUNCTION_MAP { {"cublasTSTgemvBatched_64", {"hipblasTSTgemvBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, UNSUPPORTED}}, {"cublasTSSgemvBatched", {"hipblasTSSgemvBatched", "rocblas_tssgemv_batched", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, HIP_UNSUPPORTED}}, {"cublasTSSgemvBatched_64", {"hipblasTSSgemvBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, UNSUPPORTED}}, - {"cublasSgemvStridedBatched", {"hipblasSgemvStridedBatched", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, UNSUPPORTED}}, + {"cublasSgemvStridedBatched", {"hipblasSgemvStridedBatched", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, ROC_UNSUPPORTED}}, {"cublasSgemvStridedBatched_64", {"hipblasSgemvStridedBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, UNSUPPORTED}}, - {"cublasDgemvStridedBatched", {"hipblasDgemvStridedBatched", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, UNSUPPORTED}}, + {"cublasDgemvStridedBatched", {"hipblasDgemvStridedBatched", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, ROC_UNSUPPORTED}}, {"cublasDgemvStridedBatched_64", {"hipblasDgemvStridedBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, UNSUPPORTED}}, {"cublasCgemvStridedBatched", {"hipblasCgemvStridedBatched_v2", "rocblas_cgemv_strided_batched", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3}}, {"cublasCgemvStridedBatched_64", {"hipblasCgemvStridedBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, UNSUPPORTED}}, @@ -1455,7 +1455,7 @@ const std::map CUDA_BLAS_FUNCTION_VER_MAP { {"cublasTSTgemvBatched_64", {CUDA_120, CUDA_0, CUDA_0 }}, {"cublasTSSgemvBatched", {CUDA_116, CUDA_0, CUDA_0 }}, {"cublasTSSgemvBatched_64", {CUDA_120, CUDA_0, CUDA_0 }}, - {"cublasSgemvStridedBatched", {CUDA_116, CUDA_0, CUDA_0 }}, + {"cublasSgemvStridedBatched", {CUDA_116, CUDA_0, CUDA_0 }}, // A: CUDA_VERSION 11062, CUBLAS_VERSION 110902, CUBLAS_VER_MAJOR 11 CUBLAS_VER_MINOR 9 CUBLAS_VER_PATCH 2 {"cublasSgemvStridedBatched_64", {CUDA_120, CUDA_0, CUDA_0 }}, {"cublasDgemvStridedBatched", {CUDA_116, CUDA_0, CUDA_0 }}, {"cublasDgemvStridedBatched_64", {CUDA_120, CUDA_0, CUDA_0 }}, @@ -2062,6 +2062,8 @@ const std::map HIP_BLAS_FUNCTION_VER_MAP { {"hipblasDgemvBatched_64", {HIP_6020, HIP_0, HIP_0, HIP_LATEST}}, {"hipblasCgemvBatched_v2_64", {HIP_6020, HIP_0, HIP_0, HIP_LATEST}}, {"hipblasZgemvBatched_v2_64", {HIP_6020, HIP_0, HIP_0, HIP_LATEST}}, + {"hipblasSgemvStridedBatched", {HIP_3000, HIP_0, HIP_0 }}, + {"hipblasDgemvStridedBatched", {HIP_3000, HIP_0, HIP_0 }}, {"rocblas_status_to_string", {HIP_3050, HIP_0, HIP_0 }}, {"rocblas_sscal", {HIP_1050, HIP_0, HIP_0 }}, diff --git a/tests/unit_tests/synthetic/libraries/cublas2hipblas.cu b/tests/unit_tests/synthetic/libraries/cublas2hipblas.cu index 25845dc1..034b8c33 100644 --- a/tests/unit_tests/synthetic/libraries/cublas2hipblas.cu +++ b/tests/unit_tests/synthetic/libraries/cublas2hipblas.cu @@ -1703,7 +1703,7 @@ int main() { blasStatus = cublasGemmStridedBatchedEx(blasHandle, transa, transb, m, n, k, aptr, Aptr, Atype, lda, strideA, Bptr, Btype, ldb, strideB, bptr, Cptr, Ctype, ldc, strideC, batchCount, blasComputeType, blasGemmAlgo); #endif -#if CUDA_VERSION >= 11060 && CUBLAS_VERSION >= 110902 // CUDA 11.6.2 +#if CUDA_VERSION > 11060 && CUBLAS_VERSION >= 110902 // CUDA 11.6.2 // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemvBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const cuComplex* alpha, const cuComplex* const Aarray[], int lda, const cuComplex* const xarray[], int incx, const cuComplex* beta, cuComplex* const yarray[], int incy, int batchCount); // HIP: HIPBLAS_EXPORT hipblasStatus_t hipblasCgemvBatched_v2(hipblasHandle_t handle, hipblasOperation_t trans, int m, int n, const hipComplex* alpha, const hipComplex* const AP[], int lda, const hipComplex* const x[], int incx, const hipComplex* beta, hipComplex* const y[], int incy, int batchCount); // CHECK: blasStatus = hipblasCgemvBatched_v2(blasHandle, blasOperation, m, n, &complexa, complexAarray_const, lda, complexXarray_const, incx, &complexb, complexYarray, incy, batchCount); diff --git a/tests/unit_tests/synthetic/libraries/cublas2hipblas_v2.cu b/tests/unit_tests/synthetic/libraries/cublas2hipblas_v2.cu index d2b8bfeb..21ef3998 100644 --- a/tests/unit_tests/synthetic/libraries/cublas2hipblas_v2.cu +++ b/tests/unit_tests/synthetic/libraries/cublas2hipblas_v2.cu @@ -1634,6 +1634,8 @@ int main() { long long int strideA = 0; long long int strideB = 0; long long int strideC = 0; + long long int strideX = 0; + long long int strideY = 0; #if CUDA_VERSION >= 7050 // CHECK: __half* ha = 0; @@ -1875,6 +1877,18 @@ int main() { blasStatus = cublasGemmStridedBatchedEx(blasHandle, transa, transb, m, n, k, aptr, Aptr, Atype, lda, strideA, Bptr, Btype, ldb, strideB, bptr, Cptr, Ctype, ldc, strideC, batchCount, blasComputeType, blasGemmAlgo); #endif +#if CUDA_VERSION > 11060 && CUBLAS_VERSION >= 110902 // CUDA 11.6.2 + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemvStridedBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const float* alpha, const float* A, int lda, long long int strideA, const float* x, int incx, long long int stridex, const float* beta, float* y, int incy, long long int stridey, int batchCount); + // HIP: HIPBLAS_EXPORT hipblasStatus_t hipblasSgemvStridedBatched(hipblasHandle_t handle, hipblasOperation_t transA, int m, int n, const float* alpha, const float* AP, int lda, hipblasStride strideA, const float* x, int incx, hipblasStride stridex, const float* beta, float* y, int incy, hipblasStride stridey, int batchCount); + // CHECK: blasStatus = hipblasSgemvStridedBatched(blasHandle, blasOperation, m, n, &fa, &fA, lda, strideA, &fx, incx, strideX, &fb, &fy, incy, strideY, batchCount); + blasStatus = cublasSgemvStridedBatched(blasHandle, blasOperation, m, n, &fa, &fA, lda, strideA, &fx, incx, strideX, &fb, &fy, incy, strideY, batchCount); + + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgemvStridedBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const double* alpha, const double* A, int lda, long long int strideA, const double* x, int incx, long long int stridex, const double* beta, double* y, int incy, long long int stridey, int batchCount); + // HIP: HIPBLAS_EXPORT hipblasStatus_t hipblasDgemvStridedBatched(hipblasHandle_t handle, hipblasOperation_t transA, int m, int n, const double* alpha, const double* AP, int lda, hipblasStride strideA, const double* x, int incx, hipblasStride stridex, const double* beta, double* y, int incy, hipblasStride stridey, int batchCount); + // CHECK: blasStatus = hipblasDgemvStridedBatched(blasHandle, blasOperation, m, n, &da, &dA, lda, strideA, &dx, incx, strideX, &db, &dy, incy, strideY, batchCount); + blasStatus = cublasDgemvStridedBatched(blasHandle, blasOperation, m, n, &da, &dA, lda, strideA, &dx, incx, strideX, &db, &dy, incy, strideY, batchCount); +#endif + #if CUDA_VERSION >= 12000 // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasIsamax_v2_64(cublasHandle_t handle, int64_t n, const float* x, int64_t incx, int64_t* result); // HIP: HIPBLAS_EXPORT hipblasStatus_t hipblasIsamax_64(hipblasHandle_t handle, int64_t n, const float* x, int64_t incx, int64_t* result);