From 2178f183bcd09b63f3bdba0350798b600a183783 Mon Sep 17 00:00:00 2001 From: Evgeny Mankov Date: Thu, 26 Sep 2024 22:33:51 +0100 Subject: [PATCH] [HIPIFY][rocBLAS] 64-bit functions support - Step 3 + `rocblas_(s|d|c|z|hsh|hss|tst|tss)gemv_batched_64` support + [fix] Added the missing support for `rocblas_(s|d)gemv_batched` + Updated synthetic tests, the regenerated `hipify-perl`, and `BLAS` `CUDA2HIP` documentation --- bin/hipify-perl | 20 ++-- .../CUBLAS_API_supported_by_HIP_and_ROC.md | 20 ++-- docs/tables/CUBLAS_API_supported_by_ROC.md | 20 ++-- src/CUDA2HIP_BLAS_API_functions.cpp | 30 +++-- .../synthetic/libraries/cublas2hipblas_v2.cu | 4 +- .../synthetic/libraries/cublas2rocblas_v2.cu | 104 ++++++++++++++++-- 6 files changed, 146 insertions(+), 52 deletions(-) diff --git a/bin/hipify-perl b/bin/hipify-perl index d69ef706..b033c089 100755 --- a/bin/hipify-perl +++ b/bin/hipify-perl @@ -1560,6 +1560,7 @@ sub rocSubstitutions { subst("cublasCgemm_v2", "rocblas_cgemm", "library"); subst("cublasCgemv", "rocblas_cgemv", "library"); subst("cublasCgemvBatched", "rocblas_cgemv_batched", "library"); + subst("cublasCgemvBatched_64", "rocblas_cgemv_batched_64", "library"); subst("cublasCgemvStridedBatched", "rocblas_cgemv_strided_batched", "library"); subst("cublasCgemv_64", "rocblas_cgemv_64", "library"); subst("cublasCgemv_v2", "rocblas_cgemv", "library"); @@ -1672,6 +1673,8 @@ sub rocSubstitutions { subst("cublasDgemmStridedBatched", "rocblas_dgemm_strided_batched", "library"); subst("cublasDgemm_v2", "rocblas_dgemm", "library"); subst("cublasDgemv", "rocblas_dgemv", "library"); + subst("cublasDgemvBatched", "rocblas_dgemv_batched", "library"); + subst("cublasDgemvBatched_64", "rocblas_dgemv_batched_64", "library"); subst("cublasDgemv_64", "rocblas_dgemv_64", "library"); subst("cublasDgemv_v2", "rocblas_dgemv", "library"); subst("cublasDgemv_v2_64", "rocblas_dgemv_64", "library"); @@ -1766,8 +1769,10 @@ sub rocSubstitutions { subst("cublasGetVector", "rocblas_get_vector", "library"); subst("cublasGetVectorAsync", "rocblas_get_vector_async", "library"); subst("cublasHSHgemvBatched", "rocblas_hshgemv_batched", "library"); + subst("cublasHSHgemvBatched_64", "rocblas_hshgemv_batched_64", "library"); subst("cublasHSHgemvStridedBatched", "rocblas_hshgemv_strided_batched", "library"); subst("cublasHSSgemvBatched", "rocblas_hssgemv_batched", "library"); + subst("cublasHSSgemvBatched_64", "rocblas_hssgemv_batched_64", "library"); subst("cublasHSSgemvStridedBatched", "rocblas_hssgemv_strided_batched", "library"); subst("cublasHgemm", "rocblas_hgemm", "library"); subst("cublasHgemmBatched", "rocblas_hgemm_batched", "library"); @@ -1856,6 +1861,8 @@ sub rocSubstitutions { subst("cublasSgemmStridedBatched", "rocblas_sgemm_strided_batched", "library"); subst("cublasSgemm_v2", "rocblas_sgemm", "library"); subst("cublasSgemv", "rocblas_sgemv", "library"); + subst("cublasSgemvBatched", "rocblas_sgemv_batched", "library"); + subst("cublasSgemvBatched_64", "rocblas_sgemv_batched_64", "library"); subst("cublasSgemv_64", "rocblas_sgemv_64", "library"); subst("cublasSgemv_v2", "rocblas_sgemv", "library"); subst("cublasSgemv_v2_64", "rocblas_sgemv_64", "library"); @@ -1924,8 +1931,10 @@ sub rocSubstitutions { subst("cublasStrsv", "rocblas_strsv", "library"); subst("cublasStrsv_v2", "rocblas_strsv", "library"); subst("cublasTSSgemvBatched", "rocblas_tssgemv_batched", "library"); + subst("cublasTSSgemvBatched_64", "rocblas_tssgemv_batched_64", "library"); subst("cublasTSSgemvStridedBatched", "rocblas_tssgemv_strided_batched", "library"); subst("cublasTSTgemvBatched", "rocblas_tstgemv_batched", "library"); + subst("cublasTSTgemvBatched_64", "rocblas_tstgemv_batched_64", "library"); subst("cublasTSTgemvStridedBatched", "rocblas_tstgemv_strided_batched", "library"); subst("cublasZaxpy", "rocblas_zaxpy", "library"); subst("cublasZaxpy_64", "rocblas_zaxpy_64", "library"); @@ -1963,6 +1972,7 @@ sub rocSubstitutions { subst("cublasZgemm_v2", "rocblas_zgemm", "library"); subst("cublasZgemv", "rocblas_zgemv", "library"); subst("cublasZgemvBatched", "rocblas_zgemv_batched", "library"); + subst("cublasZgemvBatched_64", "rocblas_zgemv_batched_64", "library"); subst("cublasZgemvStridedBatched", "rocblas_zgemv_strided_batched", "library"); subst("cublasZgemv_64", "rocblas_zgemv_64", "library"); subst("cublasZgemv_v2", "rocblas_zgemv", "library"); @@ -12490,7 +12500,6 @@ sub warnRocOnlyUnsupportedFunctions { "cublasZgerc_64", "cublasZgeqrfBatched", "cublasZgemvStridedBatched_64", - "cublasZgemvBatched_64", "cublasZgemm_v2_64", "cublasZgemm_64", "cublasZgemmStridedBatched_64", @@ -12503,9 +12512,7 @@ sub warnRocOnlyUnsupportedFunctions { "cublasXerbla", "cublasUint8gemmBias", "cublasTSTgemvStridedBatched_64", - "cublasTSTgemvBatched_64", "cublasTSSgemvStridedBatched_64", - "cublasTSSgemvBatched_64", "cublasSwapEx_64", "cublasSwapEx", "cublasStrttp", @@ -12558,8 +12565,6 @@ sub warnRocOnlyUnsupportedFunctions { "cublasSgeqrfBatched", "cublasSgemvStridedBatched_64", "cublasSgemvStridedBatched", - "cublasSgemvBatched_64", - "cublasSgemvBatched", "cublasSgemm_v2_64", "cublasSgemm_64", "cublasSgemmStridedBatched_64", @@ -12656,9 +12661,7 @@ sub warnRocOnlyUnsupportedFunctions { "cublasHgemmStridedBatched_64", "cublasHgemmBatched_64", "cublasHSSgemvStridedBatched_64", - "cublasHSSgemvBatched_64", "cublasHSHgemvStridedBatched_64", - "cublasHSHgemvBatched_64", "cublasGetVersion_v2", "cublasGetVersion", "cublasGetVector_64", @@ -12726,8 +12729,6 @@ sub warnRocOnlyUnsupportedFunctions { "cublasDgeqrfBatched", "cublasDgemvStridedBatched_64", "cublasDgemvStridedBatched", - "cublasDgemvBatched_64", - "cublasDgemvBatched", "cublasDgemm_v2_64", "cublasDgemm_64", "cublasDgemmStridedBatched_64", @@ -12810,7 +12811,6 @@ sub warnRocOnlyUnsupportedFunctions { "cublasCgerc_64", "cublasCgeqrfBatched", "cublasCgemvStridedBatched_64", - "cublasCgemvBatched_64", "cublasCgemm_v2_64", "cublasCgemm_64", "cublasCgemmStridedBatched_64", diff --git a/docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md b/docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md index 444ccfff..141afe35 100644 --- a/docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md +++ b/docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md @@ -1032,7 +1032,7 @@ |`cublasCgemm_v2`| | | | |`hipblasCgemm_v2`|6.0.0| | | | |`rocblas_cgemm`|1.5.0| | | | | |`cublasCgemm_v2_64`|12.0| | | | | | | | | | | | | | | | |`cublasCgemvBatched`|11.6| | | |`hipblasCgemvBatched_v2`|6.0.0| | | | |`rocblas_cgemv_batched`|3.5.0| | | | | -|`cublasCgemvBatched_64`|12.0| | | |`hipblasCgemvBatched_v2_64`|6.2.0| | | | | | | | | | | +|`cublasCgemvBatched_64`|12.0| | | |`hipblasCgemvBatched_v2_64`|6.2.0| | | | |`rocblas_cgemv_batched_64`|6.2.0| | | | | |`cublasCgemvStridedBatched`|11.6| | | |`hipblasCgemvStridedBatched_v2`|6.0.0| | | | |`rocblas_cgemv_strided_batched`|3.5.0| | | | | |`cublasCgemvStridedBatched_64`|12.0| | | |`hipblasCgemvStridedBatched_v2_64`|6.2.0| | | | | | | | | | | |`cublasChemm`| | | | |`hipblasChemm_v2`|6.0.0| | | | |`rocblas_chemm`|3.5.0| | | | | @@ -1081,8 +1081,8 @@ |`cublasDgemm_64`|12.0| | | | | | | | | | | | | | | | |`cublasDgemm_v2`| | | | |`hipblasDgemm`|1.8.2| | | | |`rocblas_dgemm`|1.5.0| | | | | |`cublasDgemm_v2_64`|12.0| | | | | | | | | | | | | | | | -|`cublasDgemvBatched`|11.6| | | |`hipblasDgemvBatched`|3.0.0| | | | | | | | | | | -|`cublasDgemvBatched_64`|12.0| | | |`hipblasDgemvBatched_64`|6.2.0| | | | | | | | | | | +|`cublasDgemvBatched`|11.6| | | |`hipblasDgemvBatched`|3.0.0| | | | |`rocblas_dgemv_batched`|3.5.0| | | | | +|`cublasDgemvBatched_64`|12.0| | | |`hipblasDgemvBatched_64`|6.2.0| | | | |`rocblas_dgemv_batched_64`|6.2.0| | | | | |`cublasDgemvStridedBatched`|11.6| | | |`hipblasDgemvStridedBatched`|3.0.0| | | | | | | | | | | |`cublasDgemvStridedBatched_64`|12.0| | | |`hipblasDgemvStridedBatched_64`|6.2.0| | | | | | | | | | | |`cublasDsymm`| | | | |`hipblasDsymm`|3.6.0| | | | |`rocblas_dsymm`|3.5.0| | | | | @@ -1110,11 +1110,11 @@ |`cublasGemmGroupedBatchedEx`|12.5| | | | | | | | | | | | | | | | |`cublasGemmGroupedBatchedEx_64`|12.5| | | | | | | | | | | | | | | | |`cublasHSHgemvBatched`|11.6| | | | | | | | | |`rocblas_hshgemv_batched`|6.0.0| | | | | -|`cublasHSHgemvBatched_64`|12.0| | | | | | | | | | | | | | | | +|`cublasHSHgemvBatched_64`|12.0| | | | | | | | | |`rocblas_hshgemv_batched_64`|6.2.0| | | | | |`cublasHSHgemvStridedBatched`|11.6| | | | | | | | | |`rocblas_hshgemv_strided_batched`|6.0.0| | | | | |`cublasHSHgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | | | |`cublasHSSgemvBatched`|11.6| | | | | | | | | |`rocblas_hssgemv_batched`|6.0.0| | | | | -|`cublasHSSgemvBatched_64`|12.0| | | | | | | | | | | | | | | | +|`cublasHSSgemvBatched_64`|12.0| | | | | | | | | |`rocblas_hssgemv_batched_64`|6.2.0| | | | | |`cublasHSSgemvStridedBatched`|11.6| | | | | | | | | |`rocblas_hssgemv_strided_batched`|6.0.0| | | | | |`cublasHSSgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | | | |`cublasHgemm`|7.5| | | |`hipblasHgemm`|1.8.2| | | | |`rocblas_hgemm`|1.5.0| | | | | @@ -1133,8 +1133,8 @@ |`cublasSgemm_64`|12.0| | | | | | | | | | | | | | | | |`cublasSgemm_v2`| | | | |`hipblasSgemm`|1.8.2| | | | |`rocblas_sgemm`|1.5.0| | | | | |`cublasSgemm_v2_64`|12.0| | | | | | | | | | | | | | | | -|`cublasSgemvBatched`|11.6| | | |`hipblasSgemvBatched`|1.6.0| | | | | | | | | | | -|`cublasSgemvBatched_64`|12.0| | | |`hipblasSgemvBatched_64`|6.2.0| | | | | | | | | | | +|`cublasSgemvBatched`|11.6| | | |`hipblasSgemvBatched`|1.6.0| | | | |`rocblas_sgemv_batched`|3.5.0| | | | | +|`cublasSgemvBatched_64`|12.0| | | |`hipblasSgemvBatched_64`|6.2.0| | | | |`rocblas_sgemv_batched_64`|6.2.0| | | | | |`cublasSgemvStridedBatched`|11.6| | | |`hipblasSgemvStridedBatched`|3.0.0| | | | | | | | | | | |`cublasSgemvStridedBatched_64`|12.0| | | |`hipblasSgemvStridedBatched_64`|6.2.0| | | | | | | | | | | |`cublasSsymm`| | | | |`hipblasSsymm`|3.6.0| | | | |`rocblas_ssymm`|3.5.0| | | | | @@ -1160,11 +1160,11 @@ |`cublasStrsm_v2`| | | | |`hipblasStrsm`|1.8.2| | | | |`rocblas_strsm`|1.5.0| | | | | |`cublasStrsm_v2_64`|12.0| | | | | | | | | | | | | | | | |`cublasTSSgemvBatched`|11.6| | | | | | | | | |`rocblas_tssgemv_batched`|6.0.0| | | | | -|`cublasTSSgemvBatched_64`|12.0| | | | | | | | | | | | | | | | +|`cublasTSSgemvBatched_64`|12.0| | | | | | | | | |`rocblas_tssgemv_batched_64`|6.2.0| | | | | |`cublasTSSgemvStridedBatched`|11.6| | | | | | | | | |`rocblas_tssgemv_strided_batched`|6.0.0| | | | | |`cublasTSSgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | | | |`cublasTSTgemvBatched`|11.6| | | | | | | | | |`rocblas_tstgemv_batched`|6.0.0| | | | | -|`cublasTSTgemvBatched_64`|12.0| | | | | | | | | | | | | | | | +|`cublasTSTgemvBatched_64`|12.0| | | | | | | | | |`rocblas_tstgemv_batched_64`|6.2.0| | | | | |`cublasTSTgemvStridedBatched`|11.6| | | | | | | | | |`rocblas_tstgemv_strided_batched`|6.0.0| | | | | |`cublasTSTgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | | | |`cublasZgemm`| | | | |`hipblasZgemm_v2`|6.0.0| | | | |`rocblas_zgemm`|1.5.0| | | | | @@ -1178,7 +1178,7 @@ |`cublasZgemm_v2`| | | | |`hipblasZgemm_v2`|6.0.0| | | | |`rocblas_zgemm`|1.5.0| | | | | |`cublasZgemm_v2_64`|12.0| | | | | | | | | | | | | | | | |`cublasZgemvBatched`|11.6| | | |`hipblasZgemvBatched_v2`|6.0.0| | | | |`rocblas_zgemv_batched`|3.5.0| | | | | -|`cublasZgemvBatched_64`|12.0| | | |`hipblasZgemvBatched_v2_64`|6.2.0| | | | | | | | | | | +|`cublasZgemvBatched_64`|12.0| | | |`hipblasZgemvBatched_v2_64`|6.2.0| | | | |`rocblas_zgemv_batched_64`|6.2.0| | | | | |`cublasZgemvStridedBatched`|11.6| | | |`hipblasZgemvStridedBatched_v2`|6.0.0| | | | |`rocblas_zgemv_strided_batched`|3.5.0| | | | | |`cublasZgemvStridedBatched_64`|12.0| | | |`hipblasZgemvStridedBatched_v2_64`|6.2.0| | | | | | | | | | | |`cublasZhemm`| | | | |`hipblasZhemm_v2`|6.0.0| | | | |`rocblas_zhemm`|3.5.0| | | | | diff --git a/docs/tables/CUBLAS_API_supported_by_ROC.md b/docs/tables/CUBLAS_API_supported_by_ROC.md index 07b4d608..5c0fa35d 100644 --- a/docs/tables/CUBLAS_API_supported_by_ROC.md +++ b/docs/tables/CUBLAS_API_supported_by_ROC.md @@ -1032,7 +1032,7 @@ |`cublasCgemm_v2`| | | | |`rocblas_cgemm`|1.5.0| | | | | |`cublasCgemm_v2_64`|12.0| | | | | | | | | | |`cublasCgemvBatched`|11.6| | | |`rocblas_cgemv_batched`|3.5.0| | | | | -|`cublasCgemvBatched_64`|12.0| | | | | | | | | | +|`cublasCgemvBatched_64`|12.0| | | |`rocblas_cgemv_batched_64`|6.2.0| | | | | |`cublasCgemvStridedBatched`|11.6| | | |`rocblas_cgemv_strided_batched`|3.5.0| | | | | |`cublasCgemvStridedBatched_64`|12.0| | | | | | | | | | |`cublasChemm`| | | | |`rocblas_chemm`|3.5.0| | | | | @@ -1081,8 +1081,8 @@ |`cublasDgemm_64`|12.0| | | | | | | | | | |`cublasDgemm_v2`| | | | |`rocblas_dgemm`|1.5.0| | | | | |`cublasDgemm_v2_64`|12.0| | | | | | | | | | -|`cublasDgemvBatched`|11.6| | | | | | | | | | -|`cublasDgemvBatched_64`|12.0| | | | | | | | | | +|`cublasDgemvBatched`|11.6| | | |`rocblas_dgemv_batched`|3.5.0| | | | | +|`cublasDgemvBatched_64`|12.0| | | |`rocblas_dgemv_batched_64`|6.2.0| | | | | |`cublasDgemvStridedBatched`|11.6| | | | | | | | | | |`cublasDgemvStridedBatched_64`|12.0| | | | | | | | | | |`cublasDsymm`| | | | |`rocblas_dsymm`|3.5.0| | | | | @@ -1110,11 +1110,11 @@ |`cublasGemmGroupedBatchedEx`|12.5| | | | | | | | | | |`cublasGemmGroupedBatchedEx_64`|12.5| | | | | | | | | | |`cublasHSHgemvBatched`|11.6| | | |`rocblas_hshgemv_batched`|6.0.0| | | | | -|`cublasHSHgemvBatched_64`|12.0| | | | | | | | | | +|`cublasHSHgemvBatched_64`|12.0| | | |`rocblas_hshgemv_batched_64`|6.2.0| | | | | |`cublasHSHgemvStridedBatched`|11.6| | | |`rocblas_hshgemv_strided_batched`|6.0.0| | | | | |`cublasHSHgemvStridedBatched_64`|12.0| | | | | | | | | | |`cublasHSSgemvBatched`|11.6| | | |`rocblas_hssgemv_batched`|6.0.0| | | | | -|`cublasHSSgemvBatched_64`|12.0| | | | | | | | | | +|`cublasHSSgemvBatched_64`|12.0| | | |`rocblas_hssgemv_batched_64`|6.2.0| | | | | |`cublasHSSgemvStridedBatched`|11.6| | | |`rocblas_hssgemv_strided_batched`|6.0.0| | | | | |`cublasHSSgemvStridedBatched_64`|12.0| | | | | | | | | | |`cublasHgemm`|7.5| | | |`rocblas_hgemm`|1.5.0| | | | | @@ -1133,8 +1133,8 @@ |`cublasSgemm_64`|12.0| | | | | | | | | | |`cublasSgemm_v2`| | | | |`rocblas_sgemm`|1.5.0| | | | | |`cublasSgemm_v2_64`|12.0| | | | | | | | | | -|`cublasSgemvBatched`|11.6| | | | | | | | | | -|`cublasSgemvBatched_64`|12.0| | | | | | | | | | +|`cublasSgemvBatched`|11.6| | | |`rocblas_sgemv_batched`|3.5.0| | | | | +|`cublasSgemvBatched_64`|12.0| | | |`rocblas_sgemv_batched_64`|6.2.0| | | | | |`cublasSgemvStridedBatched`|11.6| | | | | | | | | | |`cublasSgemvStridedBatched_64`|12.0| | | | | | | | | | |`cublasSsymm`| | | | |`rocblas_ssymm`|3.5.0| | | | | @@ -1160,11 +1160,11 @@ |`cublasStrsm_v2`| | | | |`rocblas_strsm`|1.5.0| | | | | |`cublasStrsm_v2_64`|12.0| | | | | | | | | | |`cublasTSSgemvBatched`|11.6| | | |`rocblas_tssgemv_batched`|6.0.0| | | | | -|`cublasTSSgemvBatched_64`|12.0| | | | | | | | | | +|`cublasTSSgemvBatched_64`|12.0| | | |`rocblas_tssgemv_batched_64`|6.2.0| | | | | |`cublasTSSgemvStridedBatched`|11.6| | | |`rocblas_tssgemv_strided_batched`|6.0.0| | | | | |`cublasTSSgemvStridedBatched_64`|12.0| | | | | | | | | | |`cublasTSTgemvBatched`|11.6| | | |`rocblas_tstgemv_batched`|6.0.0| | | | | -|`cublasTSTgemvBatched_64`|12.0| | | | | | | | | | +|`cublasTSTgemvBatched_64`|12.0| | | |`rocblas_tstgemv_batched_64`|6.2.0| | | | | |`cublasTSTgemvStridedBatched`|11.6| | | |`rocblas_tstgemv_strided_batched`|6.0.0| | | | | |`cublasTSTgemvStridedBatched_64`|12.0| | | | | | | | | | |`cublasZgemm`| | | | |`rocblas_zgemm`|1.5.0| | | | | @@ -1178,7 +1178,7 @@ |`cublasZgemm_v2`| | | | |`rocblas_zgemm`|1.5.0| | | | | |`cublasZgemm_v2_64`|12.0| | | | | | | | | | |`cublasZgemvBatched`|11.6| | | |`rocblas_zgemv_batched`|3.5.0| | | | | -|`cublasZgemvBatched_64`|12.0| | | | | | | | | | +|`cublasZgemvBatched_64`|12.0| | | |`rocblas_zgemv_batched_64`|6.2.0| | | | | |`cublasZgemvStridedBatched`|11.6| | | |`rocblas_zgemv_strided_batched`|3.5.0| | | | | |`cublasZgemvStridedBatched_64`|12.0| | | | | | | | | | |`cublasZhemm`| | | | |`rocblas_zhemm`|3.5.0| | | | | diff --git a/src/CUDA2HIP_BLAS_API_functions.cpp b/src/CUDA2HIP_BLAS_API_functions.cpp index 7853f554..84ebe0bd 100644 --- a/src/CUDA2HIP_BLAS_API_functions.cpp +++ b/src/CUDA2HIP_BLAS_API_functions.cpp @@ -442,22 +442,22 @@ const std::map CUDA_BLAS_FUNCTION_MAP { {"cublasGemmGroupedBatchedEx_64", {"hipblasGemmGroupedBatchedEx_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, UNSUPPORTED}}, // BATCH GEMV - {"cublasSgemvBatched", {"hipblasSgemvBatched", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, ROC_UNSUPPORTED}}, - {"cublasSgemvBatched_64", {"hipblasSgemvBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, ROC_UNSUPPORTED}}, - {"cublasDgemvBatched", {"hipblasDgemvBatched", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, ROC_UNSUPPORTED}}, - {"cublasDgemvBatched_64", {"hipblasDgemvBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, ROC_UNSUPPORTED}}, + {"cublasSgemvBatched", {"hipblasSgemvBatched", "rocblas_sgemv_batched", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3}}, + {"cublasSgemvBatched_64", {"hipblasSgemvBatched_64", "rocblas_sgemv_batched_64", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3}}, + {"cublasDgemvBatched", {"hipblasDgemvBatched", "rocblas_dgemv_batched", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3}}, + {"cublasDgemvBatched_64", {"hipblasDgemvBatched_64", "rocblas_dgemv_batched_64", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3}}, {"cublasCgemvBatched", {"hipblasCgemvBatched_v2", "rocblas_cgemv_batched", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3}}, - {"cublasCgemvBatched_64", {"hipblasCgemvBatched_v2_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, ROC_UNSUPPORTED}}, + {"cublasCgemvBatched_64", {"hipblasCgemvBatched_v2_64", "rocblas_cgemv_batched_64", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3}}, {"cublasZgemvBatched", {"hipblasZgemvBatched_v2", "rocblas_zgemv_batched", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3}}, - {"cublasZgemvBatched_64", {"hipblasZgemvBatched_v2_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, ROC_UNSUPPORTED}}, + {"cublasZgemvBatched_64", {"hipblasZgemvBatched_v2_64", "rocblas_zgemv_batched_64", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3}}, {"cublasHSHgemvBatched", {"hipblasHSHgemvBatched", "rocblas_hshgemv_batched", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, HIP_UNSUPPORTED}}, - {"cublasHSHgemvBatched_64", {"hipblasHSHgemvBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, UNSUPPORTED}}, + {"cublasHSHgemvBatched_64", {"hipblasHSHgemvBatched_64", "rocblas_hshgemv_batched_64", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, HIP_UNSUPPORTED}}, {"cublasHSSgemvBatched", {"hipblasHSSgemvBatched", "rocblas_hssgemv_batched", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, HIP_UNSUPPORTED}}, - {"cublasHSSgemvBatched_64", {"hipblasHSSgemvBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, UNSUPPORTED}}, + {"cublasHSSgemvBatched_64", {"hipblasHSSgemvBatched_64", "rocblas_hssgemv_batched_64", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, HIP_UNSUPPORTED}}, {"cublasTSTgemvBatched", {"hipblasTSTgemvBatched", "rocblas_tstgemv_batched", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, HIP_UNSUPPORTED}}, - {"cublasTSTgemvBatched_64", {"hipblasTSTgemvBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, UNSUPPORTED}}, + {"cublasTSTgemvBatched_64", {"hipblasTSTgemvBatched_64", "rocblas_tstgemv_batched_64", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, HIP_UNSUPPORTED}}, {"cublasTSSgemvBatched", {"hipblasTSSgemvBatched", "rocblas_tssgemv_batched", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, HIP_UNSUPPORTED}}, - {"cublasTSSgemvBatched_64", {"hipblasTSSgemvBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, UNSUPPORTED}}, + {"cublasTSSgemvBatched_64", {"hipblasTSSgemvBatched_64", "rocblas_tssgemv_batched_64", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, HIP_UNSUPPORTED}}, {"cublasSgemvStridedBatched", {"hipblasSgemvStridedBatched", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, ROC_UNSUPPORTED}}, {"cublasSgemvStridedBatched_64", {"hipblasSgemvStridedBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, ROC_UNSUPPORTED}}, {"cublasDgemvStridedBatched", {"hipblasDgemvStridedBatched", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, ROC_UNSUPPORTED}}, @@ -2319,6 +2319,16 @@ const std::map HIP_BLAS_FUNCTION_VER_MAP { {"rocblas_dgemv_64", {HIP_6020, HIP_0, HIP_0 }}, {"rocblas_cgemv_64", {HIP_6020, HIP_0, HIP_0 }}, {"rocblas_zgemv_64", {HIP_6020, HIP_0, HIP_0 }}, + {"rocblas_sgemv_batched", {HIP_3050, HIP_0, HIP_0 }}, + {"rocblas_dgemv_batched", {HIP_3050, HIP_0, HIP_0 }}, + {"rocblas_sgemv_batched_64", {HIP_6020, HIP_0, HIP_0 }}, + {"rocblas_dgemv_batched_64", {HIP_6020, HIP_0, HIP_0 }}, + {"rocblas_cgemv_batched_64", {HIP_6020, HIP_0, HIP_0 }}, + {"rocblas_zgemv_batched_64", {HIP_6020, HIP_0, HIP_0 }}, + {"rocblas_hshgemv_batched_64", {HIP_6020, HIP_0, HIP_0 }}, + {"rocblas_hssgemv_batched_64", {HIP_6020, HIP_0, HIP_0 }}, + {"rocblas_tstgemv_batched_64", {HIP_6020, HIP_0, HIP_0 }}, + {"rocblas_tssgemv_batched_64", {HIP_6020, HIP_0, HIP_0 }}, }; const std::map HIP_BLAS_FUNCTION_CHANGED_VER_MAP { diff --git a/tests/unit_tests/synthetic/libraries/cublas2hipblas_v2.cu b/tests/unit_tests/synthetic/libraries/cublas2hipblas_v2.cu index 8e081166..9e7f7e54 100644 --- a/tests/unit_tests/synthetic/libraries/cublas2hipblas_v2.cu +++ b/tests/unit_tests/synthetic/libraries/cublas2hipblas_v2.cu @@ -1888,12 +1888,12 @@ int main() { // HIP: HIPBLAS_EXPORT hipblasStatus_t hipblasDgemvStridedBatched(hipblasHandle_t handle, hipblasOperation_t transA, int m, int n, const double* alpha, const double* AP, int lda, hipblasStride strideA, const double* x, int incx, hipblasStride stridex, const double* beta, double* y, int incy, hipblasStride stridey, int batchCount); // CHECK: blasStatus = hipblasDgemvStridedBatched(blasHandle, blasOperation, m, n, &da, &dA, lda, strideA, &dx, incx, strideX, &db, &dy, incy, strideY, batchCount); blasStatus = cublasDgemvStridedBatched(blasHandle, blasOperation, m, n, &da, &dA, lda, strideA, &dx, incx, strideX, &db, &dy, incy, strideY, batchCount); - + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemvBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const float* alpha, const float* const Aarray[], int lda, const float* const xarray[], int incx, const float* beta, float* const yarray[], int incy, int batchCount); // HIP: HIPBLAS_EXPORT hipblasStatus_t hipblasSgemvBatched(hipblasHandle_t handle, hipblasOperation_t trans, int m, int n, const float* alpha, const float* const AP[], int lda, const float* const x[], int incx, const float* beta, float* const y[], int incy, int batchCount); // CHECK: blasStatus = hipblasSgemvBatched(blasHandle, blasOperation, m, n, &fa, fAarray_const, lda, fXarray_const, incx, &fb, fYarray, incy, batchCount); blasStatus = cublasSgemvBatched(blasHandle, blasOperation, m, n, &fa, fAarray_const, lda, fXarray_const, incx, &fb, fYarray, incy, batchCount); - + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgemvBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const double* alpha, const double* const Aarray[], int lda, const double* const xarray[], int incx, const double* beta, double* const yarray[], int incy, int batchCount); // HIP: HIPBLAS_EXPORT hipblasStatus_t hipblasDgemvBatched(hipblasHandle_t handle, hipblasOperation_t trans, int m, int n, const double* alpha, const double* const AP[], int lda, const double* const x[], int incx, const double* beta, double* const y[], int incy, int batchCount); // CHECK: blasStatus = hipblasDgemvBatched(blasHandle, blasOperation, m, n, &da, dAarray_const, lda, dXarray_const, incx, &db, dYarray, incy, batchCount); diff --git a/tests/unit_tests/synthetic/libraries/cublas2rocblas_v2.cu b/tests/unit_tests/synthetic/libraries/cublas2rocblas_v2.cu index f336ed81..1ffb03fb 100644 --- a/tests/unit_tests/synthetic/libraries/cublas2rocblas_v2.cu +++ b/tests/unit_tests/synthetic/libraries/cublas2rocblas_v2.cu @@ -171,6 +171,7 @@ int main() { int ku = 0; int64_t ku_64 = 0; int batchCount = 0; + int64_t batchCount_64 = 0; void *image = nullptr; void *image_2 = nullptr; void *valpha = nullptr; @@ -259,12 +260,16 @@ int main() { float fresult = 0; float fparam = 0; - float** fAarray = 0; + float** fAarray = nullptr; const float** const fAarray_const = const_cast(fAarray); - float** fBarray = 0; + float** fBarray = nullptr; const float** const fBarray_const = const_cast(fBarray); - float** fCarray = 0; - float** fTauarray = 0; + float** fXarray = nullptr; + const float** const fXarray_const = const_cast(fXarray); + float** fYarray = nullptr; + const float** const fYarray_const = const_cast(fYarray); + float** fCarray = nullptr; + float** fTauarray = nullptr; double da = 0; double dA = 0; @@ -282,12 +287,16 @@ int main() { double dresult = 0; double dparam = 0; - double** dAarray = 0; + double** dAarray = nullptr; const double** const dAarray_const = const_cast(dAarray); - double** dBarray = 0; + double** dBarray = nullptr; const double** const dBarray_const = const_cast(dBarray); - double** dCarray = 0; - double** dTauarray = 0; + double** dXarray = nullptr; + const double** const dXarray_const = const_cast(dXarray); + double** dYarray = nullptr; + const double** const dYarray_const = const_cast(dYarray); + double** dCarray = nullptr; + double** dTauarray = nullptr; void** voidAarray = nullptr; const void** const voidAarray_const = const_cast(voidAarray); @@ -319,26 +328,42 @@ int main() { // CHECK: rocblas_float_complex** complexAarray = 0; // CHECK: const rocblas_float_complex** const complexAarray_const = const_cast(complexAarray); // CHECK-NEXT: rocblas_float_complex** complexBarray = 0; - // CHECK: const rocblas_float_complex** const complexBarray_const = const_cast(complexBarray); + // CHECK-NEXT: const rocblas_float_complex** const complexBarray_const = const_cast(complexBarray); + // CHECK-NEXT: rocblas_float_complex** complexXarray = 0; + // CHECK-NEXT: const rocblas_float_complex** const complexXarray_const = const_cast(complexXarray); + // CHECK-NEXT: rocblas_float_complex** complexYarray = 0; + // CHECK-NEXT: const rocblas_float_complex** const complexYarray_const = const_cast(complexYarray); // CHECK-NEXT: rocblas_float_complex** complexCarray = 0; // CHECK-NEXT: rocblas_float_complex** complexTauarray = 0; cuComplex** complexAarray = 0; const cuComplex** const complexAarray_const = const_cast(complexAarray); cuComplex** complexBarray = 0; const cuComplex** const complexBarray_const = const_cast(complexBarray); + cuComplex** complexXarray = 0; + const cuComplex** const complexXarray_const = const_cast(complexXarray); + cuComplex** complexYarray = 0; + const cuComplex** const complexYarray_const = const_cast(complexYarray); cuComplex** complexCarray = 0; cuComplex** complexTauarray = 0; // CHECK: rocblas_double_complex** dcomplexAarray = 0; // CHECK: const rocblas_double_complex** const dcomplexAarray_const = const_cast(dcomplexAarray); // CHECK-NEXT: rocblas_double_complex** dcomplexBarray = 0; - // CHECK: const rocblas_double_complex** const dcomplexBarray_const = const_cast(dcomplexBarray); + // CHECK-NEXT: const rocblas_double_complex** const dcomplexBarray_const = const_cast(dcomplexBarray); + // CHECK-NEXT: rocblas_double_complex** dcomplexXarray = 0; + // CHECK-NEXT: const rocblas_double_complex** const dcomplexXarray_const = const_cast(dcomplexXarray); + // CHECK-NEXT: rocblas_double_complex** dcomplexYarray = 0; + // CHECK-NEXT: const rocblas_double_complex** const dcomplexYarray_const = const_cast(dcomplexYarray); // CHECK-NEXT: rocblas_double_complex** dcomplexCarray = 0; // CHECK-NEXT: rocblas_double_complex** dcomplexTauarray = 0; cuDoubleComplex** dcomplexAarray = 0; const cuDoubleComplex** const dcomplexAarray_const = const_cast(dcomplexAarray); cuDoubleComplex** dcomplexBarray = 0; const cuDoubleComplex** const dcomplexBarray_const = const_cast(dcomplexBarray); + cuDoubleComplex** dcomplexXarray = 0; + const cuDoubleComplex** const dcomplexXarray_const = const_cast(dcomplexXarray); + cuDoubleComplex** dcomplexYarray = 0; + const cuDoubleComplex** const dcomplexYarray_const = const_cast(dcomplexYarray); cuDoubleComplex** dcomplexCarray = 0; cuDoubleComplex** dcomplexTauarray = 0; @@ -1932,6 +1957,18 @@ int main() { // CHECK-NEXT: rocblas_datatype C_16BF = rocblas_datatype_bf16_c; cublasDataType_t R_16BF = CUDA_R_16BF; cublasDataType_t C_16BF = CUDA_C_16BF; + + // CHECK: rocblas_bfloat16** bfAarray = 0; + __nv_bfloat16** bfAarray = 0; + // CHECK: const rocblas_bfloat16** const bfAarray_const = const_cast(bfAarray); + const __nv_bfloat16** const bfAarray_const = const_cast(bfAarray); + // CHECK: rocblas_bfloat16** bfXarray = 0; + __nv_bfloat16** bfXarray = 0; + // CHECK: const rocblas_bfloat16** const bfXarray_const = const_cast(bfXarray); + const __nv_bfloat16** const bfXarray_const = const_cast(bfXarray); + __nv_bfloat16** bfYarray = 0; + // CHECK: const rocblas_bfloat16** const bfYarray_const = const_cast(bfYarray); + const __nv_bfloat16** const bfYarray_const = const_cast(bfYarray); #endif #if CUDA_VERSION >= 11040 && CUBLAS_VERSION >= 11600 @@ -1941,6 +1978,18 @@ int main() { const_ch = cublasGetStatusString(blasStatus); #endif +#if CUDA_VERSION > 11060 && CUBLAS_VERSION >= 110902 // CUDA 11.6.2 + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemvBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const float* alpha, const float* const Aarray[], int lda, const float* const xarray[], int incx, const float* beta, float* const yarray[], int incy, int batchCount); + // ROC: ROCBLAS_EXPORT rocblas_status rocblas_sgemv_batched(rocblas_handle handle, rocblas_operation trans, rocblas_int m, rocblas_int n, const float* alpha, const float* const A[], rocblas_int lda, const float* const x[], rocblas_int incx, const float* beta, float* const y[], rocblas_int incy, rocblas_int batch_count); + // CHECK: blasStatus = rocblas_sgemv_batched(blasHandle, blasOperation, m, n, &fa, fAarray_const, lda, fXarray_const, incx, &fb, fYarray, incy, batchCount); + blasStatus = cublasSgemvBatched(blasHandle, blasOperation, m, n, &fa, fAarray_const, lda, fXarray_const, incx, &fb, fYarray, incy, batchCount); + + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgemvBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const double* alpha, const double* const Aarray[], int lda, const double* const xarray[], int incx, const double* beta, double* const yarray[], int incy, int batchCount); + // ROC: ROCBLAS_EXPORT rocblas_status rocblas_dgemv_batched(rocblas_handle handle, rocblas_operation trans, rocblas_int m, rocblas_int n, const double* alpha, const double* const A[], rocblas_int lda, const double* const x[], rocblas_int incx, const double* beta, double* const y[], rocblas_int incy, rocblas_int batch_count); + // CHECK: blasStatus = rocblas_dgemv_batched(blasHandle, blasOperation, m, n, &da, dAarray_const, lda, dXarray_const, incx, &db, dYarray, incy, batchCount); + blasStatus = cublasDgemvBatched(blasHandle, blasOperation, m, n, &da, dAarray_const, lda, dXarray_const, incx, &db, dYarray, incy, batchCount); +#endif + #if CUDA_VERSION >= 12000 // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasIsamax_v2_64(cublasHandle_t handle, int64_t n, const float* x, int64_t incx, int64_t* result); // ROC: ROCBLAS_EXPORT rocblas_status rocblas_isamax_64(rocblas_handle handle, int64_t n, const float* x, int64_t incx, int64_t* result); @@ -2363,6 +2412,41 @@ int main() { // CHECK-NEXT: blasStatus = rocblas_zgemv_64(blasHandle, blasOperation, m_64, n_64, &dcomplexa, &dcomplexA, lda_64, &dcomplexx, incx_64, &dcomplexb, &dcomplexy, incy_64); blasStatus = cublasZgemv_64(blasHandle, blasOperation, m_64, n_64, &dcomplexa, &dcomplexA, lda_64, &dcomplexx, incx_64, &dcomplexb, &dcomplexy, incy_64); blasStatus = cublasZgemv_v2_64(blasHandle, blasOperation, m_64, n_64, &dcomplexa, &dcomplexA, lda_64, &dcomplexx, incx_64, &dcomplexb, &dcomplexy, incy_64); + + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemvBatched_64(cublasHandle_t handle, cublasOperation_t trans, int64_t m, int64_t n, const float* alpha, const float* const Aarray[], int64_t lda, const float* const xarray[], int64_t incx, const float* beta, float* const yarray[], int64_t incy, int64_t batchCount); + // ROC: ROCBLAS_EXPORT rocblas_status rocblas_sgemv_batched_64(rocblas_handle handle, rocblas_operation trans, int64_t m, int64_t n, const float* alpha, const float* const A[], int64_t lda, const float* const x[], int64_t incx, const float* beta, float* const y[], int64_t incy, int64_t batch_count); + // CHECK: blasStatus = rocblas_sgemv_batched_64(blasHandle, blasOperation, m_64, n_64, &fa, fAarray_const, lda_64, fXarray_const, incx_64, &fb, fYarray, incy_64, batchCount_64); + blasStatus = cublasSgemvBatched_64(blasHandle, blasOperation, m_64, n_64, &fa, fAarray_const, lda_64, fXarray_const, incx_64, &fb, fYarray, incy_64, batchCount_64); + + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemvBatched_64(cublasHandle_t handle, cublasOperation_t trans, int64_t m, int64_t n, const cuComplex* alpha, const cuComplex* const Aarray[], int64_t lda, const cuComplex* const xarray[], int64_t incx, const cuComplex* beta, cuComplex* const yarray[], int64_t incy, int64_t batchCount); + // ROC: ROCBLAS_EXPORT rocblas_status rocblas_cgemv_batched_64(rocblas_handle handle, rocblas_operation trans, int64_t m, int64_t n, const rocblas_float_complex* alpha, const rocblas_float_complex* const A[], int64_t lda, const rocblas_float_complex* const x[], int64_t incx, const rocblas_float_complex* beta, rocblas_float_complex* const y[], int64_t incy, int64_t batch_count); + // CHECK: blasStatus = rocblas_cgemv_batched_64(blasHandle, blasOperation, m_64, n_64, &complexa, complexAarray_const, lda_64, complexXarray_const, incx_64, &complexb, complexYarray, incy_64, batchCount_64); + blasStatus = cublasCgemvBatched_64(blasHandle, blasOperation, m_64, n_64, &complexa, complexAarray_const, lda_64, complexXarray_const, incx_64, &complexb, complexYarray, incy_64, batchCount_64); + + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgemvBatched_64(cublasHandle_t handle, cublasOperation_t trans, int64_t m, int64_t n, const cuDoubleComplex* alpha, const cuDoubleComplex* const Aarray[], int64_t lda, const cuDoubleComplex* const xarray[], int64_t incx, const cuDoubleComplex* beta, cuDoubleComplex* const yarray[], int64_t incy, int64_t batchCount); + // ROC: ROCBLAS_EXPORT rocblas_status rocblas_zgemv_batched_64(rocblas_handle handle, rocblas_operation trans, int64_t m, int64_t n, const rocblas_double_complex* alpha, const rocblas_double_complex* const A[], int64_t lda, const rocblas_double_complex* const x[], int64_t incx, const rocblas_double_complex* beta, rocblas_double_complex* const y[], int64_t incy, int64_t batch_count); + // CHECK: blasStatus = rocblas_zgemv_batched_64(blasHandle, blasOperation, m_64, n_64, &dcomplexa, dcomplexAarray_const, lda_64, dcomplexXarray_const, incx_64, &dcomplexb, dcomplexYarray, incy_64, batchCount_64); + blasStatus = cublasZgemvBatched_64(blasHandle, blasOperation, m_64, n_64, &dcomplexa, dcomplexAarray_const, lda_64, dcomplexXarray_const, incx_64, &dcomplexb, dcomplexYarray, incy_64, batchCount_64); + + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHSHgemvBatched_64(cublasHandle_t handle, cublasOperation_t trans, int64_t m, int64_t n, const float* alpha, const __half* const Aarray[], int64_t lda, const __half* const xarray[], int64_t incx, const float* beta, __half* const yarray[], int64_t incy, int64_t batchCount); + // ROC: ROCBLAS_EXPORT rocblas_status rocblas_hshgemv_batched_64(rocblas_handle handle, rocblas_operation trans, int64_t m, int64_t n, const float* alpha, const rocblas_half* const A[], int64_t lda, const rocblas_half* const x[], int64_t incx, const float* beta, rocblas_half* const y[], int64_t incy, int64_t batch_count); + // CHECK: blasStatus = rocblas_hshgemv_batched_64(blasHandle, blasOperation, m_64, n_64, &fa, hAarray_const, lda_64, hxarray_const, incx_64, &fb, hyarray, incy_64, batchCount_64); + blasStatus = cublasHSHgemvBatched_64(blasHandle, blasOperation, m_64, n_64, &fa, hAarray_const, lda_64, hxarray_const, incx_64, &fb, hyarray, incy_64, batchCount_64); + + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHSSgemvBatched_64(cublasHandle_t handle, cublasOperation_t trans, int64_t m, int64_t n, const float* alpha, const __half* const Aarray[], int64_t lda, const __half* const xarray[], int64_t incx, const float* beta, float* const yarray[], int64_t incy, int64_t batchCount); + // ROC: ROCBLAS_EXPORT rocblas_status rocblas_hssgemv_batched_64(rocblas_handle handle, rocblas_operation trans, int64_t m, int64_t n, const float* alpha, const rocblas_half* const A[], int64_t lda, const rocblas_half* const x[], int64_t incx, const float* beta, float* const y[], int64_t incy, int64_t batch_count); + // CHECK: blasStatus = rocblas_hssgemv_batched_64(blasHandle, blasOperation, m_64, n_64, &fa, hAarray_const, lda_64, hxarray_const, incx_64, &fb, fYarray, incy_64, batchCount_64); + blasStatus = cublasHSSgemvBatched_64(blasHandle, blasOperation, m_64, n_64, &fa, hAarray_const, lda_64, hxarray_const, incx_64, &fb, fYarray, incy_64, batchCount_64); + + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasTSTgemvBatched_64(cublasHandle_t handle, cublasOperation_t trans, int64_t m, int64_t n, const float* alpha, const __nv_bfloat16* const Aarray[], int64_t lda, const __nv_bfloat16* const xarray[], int64_t incx, const float* beta, __nv_bfloat16* const yarray[], int64_t incy, int64_t batchCount); + // ROC: ROCBLAS_EXPORT rocblas_status rocblas_tstgemv_batched_64(rocblas_handle handle, rocblas_operation trans, int64_t m, int64_t n, const float* alpha, const rocblas_bfloat16* const A[], int64_t lda, const rocblas_bfloat16* const x[], int64_t incx, const float* beta, rocblas_bfloat16* const y[], int64_t incy, int64_t batch_count); + // CHECK: blasStatus = rocblas_tstgemv_batched_64(blasHandle, blasOperation, m_64, n_64, &fa, bfAarray_const, lda_64, bfXarray_const, incx_64, &fb, bfYarray, incy_64, batchCount_64); + blasStatus = cublasTSTgemvBatched_64(blasHandle, blasOperation, m_64, n_64, &fa, bfAarray_const, lda_64, bfXarray_const, incx_64, &fb, bfYarray, incy_64, batchCount_64); + + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasTSSgemvBatched_64(cublasHandle_t handle, cublasOperation_t trans, int64_t m, int64_t n, const float* alpha, const __nv_bfloat16* const Aarray[], int64_t lda, const __nv_bfloat16* const xarray[], int64_t incx, const float* beta, float* const yarray[], int64_t incy, int64_t batchCount); + // ROC: ROCBLAS_EXPORT rocblas_status rocblas_tssgemv_batched_64(rocblas_handle handle, rocblas_operation trans, int64_t m, int64_t n, const float* alpha, const rocblas_bfloat16* const A[], int64_t lda, const rocblas_bfloat16* const x[], int64_t incx, const float* beta, float* const y[], int64_t incy, int64_t batch_count); + // CHECK: blasStatus = rocblas_tssgemv_batched_64(blasHandle, blasOperation, m_64, n_64, &fa, bfAarray_const, lda_64, bfXarray_const, incx_64, &fb, fYarray, incy_64, batchCount_64); + blasStatus = cublasTSSgemvBatched_64(blasHandle, blasOperation, m_64, n_64, &fa, bfAarray_const, lda_64, bfXarray_const, incx_64, &fb, fYarray, incy_64, batchCount_64); #endif return 0;