From 5f4e5b18137d806400b094abbbb57818d849197a Mon Sep 17 00:00:00 2001 From: Evgeny Mankov Date: Sun, 22 Oct 2023 19:26:38 +0200 Subject: [PATCH] [HIPIFY][6.0.0][BLAS] Support for ROCm HIP 6.0.0 - Step 20 - `half` and `bfloat16` Functions + `__half` -> `__half` -> `rocblas_half` + `__nv_bfloat16` -> `hip_bfloat16` -> `rocblas_bfloat16` + [rocBLAS] New functions: `rocblas_hs(h|s)gemv_batched` and `rocblas_ts(s|t)gemv_batched` + [fix] Removed a non-existing type `nv_bfloat16` + Updated synthetic tests, the regenerated hipify-perl, and docs --- bin/hipify-perl | 13 +++--- .../CUBLAS_API_supported_by_HIP_and_ROC.md | 8 ++-- docs/tables/CUBLAS_API_supported_by_ROC.md | 8 ++-- .../CUDA_Device_API_supported_by_HIP.md | 3 +- src/CUDA2HIP_BLAS_API_functions.cpp | 12 ++++-- src/CUDA2HIP_Device_types.cpp | 9 ++-- .../synthetic/libraries/cublas2rocblas.cu | 43 +++++++++++++++++++ 7 files changed, 73 insertions(+), 23 deletions(-) diff --git a/bin/hipify-perl b/bin/hipify-perl index 940a74e9..c51d73d5 100755 --- a/bin/hipify-perl +++ b/bin/hipify-perl @@ -1466,6 +1466,8 @@ sub rocSubstitutions { subst("cublasGetStream_v2", "rocblas_get_stream", "library"); subst("cublasGetVector", "rocblas_get_vector", "library"); subst("cublasGetVectorAsync", "rocblas_get_vector_async", "library"); + subst("cublasHSHgemvBatched", "rocblas_hshgemv_batched", "library"); + subst("cublasHSSgemvBatched", "rocblas_hssgemv_batched", "library"); subst("cublasHgemm", "rocblas_hgemm", "library"); subst("cublasHgemmBatched", "rocblas_hgemm_batched", "library"); subst("cublasHgemmStridedBatched", "rocblas_hgemm_strided_batched", "library"); @@ -1575,6 +1577,8 @@ sub rocSubstitutions { subst("cublasStrsm_v2", "rocblas_strsm", "library"); subst("cublasStrsv", "rocblas_strsv", "library"); subst("cublasStrsv_v2", "rocblas_strsv", "library"); + subst("cublasTSSgemvBatched", "rocblas_tssgemv_batched", "library"); + subst("cublasTSTgemvBatched", "rocblas_tstgemv_batched", "library"); subst("cublasZaxpy", "rocblas_zaxpy", "library"); subst("cublasZaxpy_v2", "rocblas_zaxpy", "library"); subst("cublasZcopy", "rocblas_zcopy", "library"); @@ -2031,6 +2035,8 @@ sub rocSubstitutions { subst("cusparseZgtsvInterleavedBatch_bufferSizeExt", "rocsparse_zgtsv_interleaved_batch_buffer_size", "library"); subst("cusparseZnnz", "rocsparse_znnz", "library"); subst("cusparseZnnz_compress", "rocsparse_znnz_compress", "library"); + subst("__half", "rocblas_half", "device_type"); + subst("__nv_bfloat16", "rocblas_bfloat16", "device_type"); subst("cublas.h", "rocblas.h", "include_cuda_main_header"); subst("cublas_v2.h", "rocblas.h", "include_cuda_main_header_v2"); subst("bsric02Info", "_rocsparse_mat_info", "type"); @@ -4052,6 +4058,7 @@ sub simpleSubstitutions { subst("__half2", "__half2", "device_type"); subst("__half2_raw", "__half2_raw", "device_type"); subst("__half_raw", "__half_raw", "device_type"); + subst("__nv_bfloat16", "hip_bfloat16", "device_type"); subst("caffe2\/core\/common_cudnn.h", "caffe2\/core\/hip\/common_miopen.h", "include"); subst("caffe2\/operators\/spatial_batch_norm_op.h", "caffe2\/operators\/hip\/spatial_batch_norm_op_miopen.hip", "include"); subst("channel_descriptor.h", "hip\/channel_descriptor.h", "include"); @@ -6757,7 +6764,6 @@ sub warnUnsupportedFunctions { "nvrtcGetLTOIRSize", "nvrtcGetLTOIR", "nv_bfloat162", - "nv_bfloat16", "memoryBarrier", "libraryPropertyType_t", "libraryPropertyType", @@ -7969,7 +7975,6 @@ sub warnUnsupportedFunctions { "__nv_bfloat16_raw", "__nv_bfloat162_raw", "__nv_bfloat162", - "__nv_bfloat16", "__curand_umul", "__NV_SATFINITE", "__NV_NOSAT", @@ -9958,11 +9963,9 @@ sub warnRocOnlyUnsupportedFunctions { "cublasTSTgemvStridedBatched_64", "cublasTSTgemvStridedBatched", "cublasTSTgemvBatched_64", - "cublasTSTgemvBatched", "cublasTSSgemvStridedBatched_64", "cublasTSSgemvStridedBatched", "cublasTSSgemvBatched_64", - "cublasTSSgemvBatched", "cublasSwapEx_64", "cublasSwapEx", "cublasStrttp", @@ -10095,11 +10098,9 @@ sub warnRocOnlyUnsupportedFunctions { "cublasHSSgemvStridedBatched_64", "cublasHSSgemvStridedBatched", "cublasHSSgemvBatched_64", - "cublasHSSgemvBatched", "cublasHSHgemvStridedBatched_64", "cublasHSHgemvStridedBatched", "cublasHSHgemvBatched_64", - "cublasHSHgemvBatched", "cublasGetVersion_v2", "cublasGetVersion", "cublasGetVector_64", diff --git a/docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md b/docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md index 9a363dba..c22872bd 100644 --- a/docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md +++ b/docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md @@ -797,11 +797,11 @@ |`cublasDtrsm_64`|12.0| | | | | | | | | | | | | | | |`cublasDtrsm_v2`| | | |`hipblasDtrsm`|1.8.2| | | | |`rocblas_dtrsm`|1.5.0| | | | | |`cublasDtrsm_v2_64`|12.0| | | | | | | | | | | | | | | -|`cublasHSHgemvBatched`|11.6| | | | | | | | | | | | | | | +|`cublasHSHgemvBatched`|11.6| | | | | | | | |`rocblas_hshgemv_batched`|6.0.0| | | |6.0.0| |`cublasHSHgemvBatched_64`|12.0| | | | | | | | | | | | | | | |`cublasHSHgemvStridedBatched`|11.6| | | | | | | | | | | | | | | |`cublasHSHgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | | -|`cublasHSSgemvBatched`|11.6| | | | | | | | | | | | | | | +|`cublasHSSgemvBatched`|11.6| | | | | | | | |`rocblas_hssgemv_batched`|6.0.0| | | |6.0.0| |`cublasHSSgemvBatched_64`|12.0| | | | | | | | | | | | | | | |`cublasHSSgemvStridedBatched`|11.6| | | | | | | | | | | | | | | |`cublasHSSgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | | @@ -845,11 +845,11 @@ |`cublasStrsm_64`|12.0| | | | | | | | | | | | | | | |`cublasStrsm_v2`| | | |`hipblasStrsm`|1.8.2| | | | |`rocblas_strsm`|1.5.0| | | | | |`cublasStrsm_v2_64`|12.0| | | | | | | | | | | | | | | -|`cublasTSSgemvBatched`|11.6| | | | | | | | | | | | | | | +|`cublasTSSgemvBatched`|11.6| | | | | | | | |`rocblas_tssgemv_batched`|6.0.0| | | |6.0.0| |`cublasTSSgemvBatched_64`|12.0| | | | | | | | | | | | | | | |`cublasTSSgemvStridedBatched`|11.6| | | | | | | | | | | | | | | |`cublasTSSgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | | -|`cublasTSTgemvBatched`|11.6| | | | | | | | | | | | | | | +|`cublasTSTgemvBatched`|11.6| | | | | | | | |`rocblas_tstgemv_batched`|6.0.0| | | |6.0.0| |`cublasTSTgemvBatched_64`|12.0| | | | | | | | | | | | | | | |`cublasTSTgemvStridedBatched`|11.6| | | | | | | | | | | | | | | |`cublasTSTgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | | diff --git a/docs/tables/CUBLAS_API_supported_by_ROC.md b/docs/tables/CUBLAS_API_supported_by_ROC.md index 40395be0..b2663305 100644 --- a/docs/tables/CUBLAS_API_supported_by_ROC.md +++ b/docs/tables/CUBLAS_API_supported_by_ROC.md @@ -797,11 +797,11 @@ |`cublasDtrsm_64`|12.0| | | | | | | | | |`cublasDtrsm_v2`| | | |`rocblas_dtrsm`|1.5.0| | | | | |`cublasDtrsm_v2_64`|12.0| | | | | | | | | -|`cublasHSHgemvBatched`|11.6| | | | | | | | | +|`cublasHSHgemvBatched`|11.6| | |`rocblas_hshgemv_batched`|6.0.0| | | |6.0.0| |`cublasHSHgemvBatched_64`|12.0| | | | | | | | | |`cublasHSHgemvStridedBatched`|11.6| | | | | | | | | |`cublasHSHgemvStridedBatched_64`|12.0| | | | | | | | | -|`cublasHSSgemvBatched`|11.6| | | | | | | | | +|`cublasHSSgemvBatched`|11.6| | |`rocblas_hssgemv_batched`|6.0.0| | | |6.0.0| |`cublasHSSgemvBatched_64`|12.0| | | | | | | | | |`cublasHSSgemvStridedBatched`|11.6| | | | | | | | | |`cublasHSSgemvStridedBatched_64`|12.0| | | | | | | | | @@ -845,11 +845,11 @@ |`cublasStrsm_64`|12.0| | | | | | | | | |`cublasStrsm_v2`| | | |`rocblas_strsm`|1.5.0| | | | | |`cublasStrsm_v2_64`|12.0| | | | | | | | | -|`cublasTSSgemvBatched`|11.6| | | | | | | | | +|`cublasTSSgemvBatched`|11.6| | |`rocblas_tssgemv_batched`|6.0.0| | | |6.0.0| |`cublasTSSgemvBatched_64`|12.0| | | | | | | | | |`cublasTSSgemvStridedBatched`|11.6| | | | | | | | | |`cublasTSSgemvStridedBatched_64`|12.0| | | | | | | | | -|`cublasTSTgemvBatched`|11.6| | | | | | | | | +|`cublasTSTgemvBatched`|11.6| | |`rocblas_tstgemv_batched`|6.0.0| | | |6.0.0| |`cublasTSTgemvBatched_64`|12.0| | | | | | | | | |`cublasTSTgemvStridedBatched`|11.6| | | | | | | | | |`cublasTSTgemvStridedBatched_64`|12.0| | | | | | | | | diff --git a/docs/tables/CUDA_Device_API_supported_by_HIP.md b/docs/tables/CUDA_Device_API_supported_by_HIP.md index 48e65397..4c2f3a8a 100644 --- a/docs/tables/CUDA_Device_API_supported_by_HIP.md +++ b/docs/tables/CUDA_Device_API_supported_by_HIP.md @@ -811,7 +811,7 @@ |`__half2`| | | |`__half2`|1.6.0| | | | | |`__half2_raw`| | | |`__half2_raw`|1.9.0| | | | | |`__half_raw`| | | |`__half_raw`|1.9.0| | | | | -|`__nv_bfloat16`|11.0| | | | | | | | | +|`__nv_bfloat16`|11.0| | |`hip_bfloat16`|3.5.0| | | | | |`__nv_bfloat162`|11.0| | | | | | | | | |`__nv_bfloat162_raw`|11.0| | | | | | | | | |`__nv_bfloat16_raw`|11.0| | | | | | | | | @@ -826,7 +826,6 @@ |`__nv_fp8x4_e5m2`|11.8| | | | | | | | | |`__nv_fp8x4_storage_t`|11.8| | | | | | | | | |`__nv_saturation_t`|11.8| | | | | | | | | -|`nv_bfloat16`|11.0| | | | | | | | | |`nv_bfloat162`|11.0| | | | | | | | | diff --git a/src/CUDA2HIP_BLAS_API_functions.cpp b/src/CUDA2HIP_BLAS_API_functions.cpp index 11230eb2..652b943b 100644 --- a/src/CUDA2HIP_BLAS_API_functions.cpp +++ b/src/CUDA2HIP_BLAS_API_functions.cpp @@ -442,13 +442,13 @@ const std::map CUDA_BLAS_FUNCTION_MAP { {"cublasCgemvBatched_64", {"hipblasCgemvBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}}, {"cublasZgemvBatched", {"hipblasZgemvBatched_v2", "rocblas_zgemv_batched", CONV_LIB_FUNC, API_BLAS, 7}}, {"cublasZgemvBatched_64", {"hipblasZgemvBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}}, - {"cublasHSHgemvBatched", {"hipblasHSHgemvBatched", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}}, + {"cublasHSHgemvBatched", {"hipblasHSHgemvBatched", "rocblas_hshgemv_batched", CONV_LIB_FUNC, API_BLAS, 7, HIP_UNSUPPORTED}}, {"cublasHSHgemvBatched_64", {"hipblasHSHgemvBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}}, - {"cublasHSSgemvBatched", {"hipblasHSSgemvBatched", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}}, + {"cublasHSSgemvBatched", {"hipblasHSSgemvBatched", "rocblas_hssgemv_batched", CONV_LIB_FUNC, API_BLAS, 7, HIP_UNSUPPORTED}}, {"cublasHSSgemvBatched_64", {"hipblasHSSgemvBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}}, - {"cublasTSTgemvBatched", {"hipblasTSTgemvBatched", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}}, + {"cublasTSTgemvBatched", {"hipblasTSTgemvBatched", "rocblas_tstgemv_batched", CONV_LIB_FUNC, API_BLAS, 7, HIP_UNSUPPORTED}}, {"cublasTSTgemvBatched_64", {"hipblasTSTgemvBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}}, - {"cublasTSSgemvBatched", {"hipblasTSSgemvBatched", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}}, + {"cublasTSSgemvBatched", {"hipblasTSSgemvBatched", "rocblas_tssgemv_batched", CONV_LIB_FUNC, API_BLAS, 7, HIP_UNSUPPORTED}}, {"cublasTSSgemvBatched_64", {"hipblasTSSgemvBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}}, {"cublasSgemvStridedBatched", {"hipblasSgemvStridedBatched", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}}, {"cublasSgemvStridedBatched_64", {"hipblasSgemvStridedBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}}, @@ -2096,6 +2096,10 @@ const std::map HIP_BLAS_FUNCTION_VER_MAP { {"rocblas_dtrmm", {HIP_3050, HIP_0, HIP_0, HIP_LATEST}}, {"rocblas_ctrmm", {HIP_3050, HIP_0, HIP_0, HIP_LATEST}}, {"rocblas_ztrmm", {HIP_3050, HIP_0, HIP_0, HIP_LATEST}}, + {"rocblas_hshgemv_batched", {HIP_6000, HIP_0, HIP_0, HIP_LATEST}}, + {"rocblas_hssgemv_batched", {HIP_6000, HIP_0, HIP_0, HIP_LATEST}}, + {"rocblas_tstgemv_batched", {HIP_6000, HIP_0, HIP_0, HIP_LATEST}}, + {"rocblas_tssgemv_batched", {HIP_6000, HIP_0, HIP_0, HIP_LATEST}}, }; const std::map HIP_BLAS_FUNCTION_CHANGED_VER_MAP { diff --git a/src/CUDA2HIP_Device_types.cpp b/src/CUDA2HIP_Device_types.cpp index f621707a..86489605 100644 --- a/src/CUDA2HIP_Device_types.cpp +++ b/src/CUDA2HIP_Device_types.cpp @@ -25,13 +25,12 @@ THE SOFTWARE. // Maps the names of CUDA Device/Host types to the corresponding HIP types const std::map CUDA_DEVICE_TYPE_NAME_MAP { // float16 Precision Device types - {"__half", {"__half", "", CONV_DEVICE_TYPE, API_RUNTIME, 2}}, + {"__half", {"__half", "rocblas_half", CONV_DEVICE_TYPE, API_RUNTIME, 2}}, {"__half_raw", {"__half_raw", "", CONV_DEVICE_TYPE, API_RUNTIME, 2}}, {"__half2", {"__half2", "", CONV_DEVICE_TYPE, API_RUNTIME, 2}}, {"__half2_raw", {"__half2_raw", "", CONV_DEVICE_TYPE, API_RUNTIME, 2}}, // Bfloat16 Precision Device types - {"__nv_bfloat16", {"__hip_bfloat16", "", CONV_DEVICE_TYPE, API_RUNTIME, 2, UNSUPPORTED}}, - {"nv_bfloat16", {"hip_bfloat16", "", CONV_DEVICE_TYPE, API_RUNTIME, 2, UNSUPPORTED}}, + {"__nv_bfloat16", {"hip_bfloat16", "rocblas_bfloat16", CONV_DEVICE_TYPE, API_RUNTIME, 2}}, {"__nv_bfloat16_raw", {"__hip_bfloat16_raw", "", CONV_DEVICE_TYPE, API_RUNTIME, 2, UNSUPPORTED}}, {"__nv_bfloat162", {"__hip_bfloat162", "", CONV_DEVICE_TYPE, API_RUNTIME, 2, UNSUPPORTED}}, {"nv_bfloat162", {"hip_bfloat162", "", CONV_DEVICE_TYPE, API_RUNTIME, 2, UNSUPPORTED}}, @@ -83,4 +82,8 @@ const std::map HIP_DEVICE_TYPE_NAME_VER_MAP { {"__half2", {HIP_1060, HIP_0, HIP_0 }}, {"__half_raw", {HIP_1090, HIP_0, HIP_0 }}, {"__half2_raw", {HIP_1090, HIP_0, HIP_0 }}, + {"hip_bfloat16", {HIP_3050, HIP_0, HIP_0 }}, + + {"rocblas_half", {HIP_1050, HIP_0, HIP_0 }}, + {"rocblas_bfloat16", {HIP_3050, HIP_0, HIP_0 }}, }; diff --git a/tests/unit_tests/synthetic/libraries/cublas2rocblas.cu b/tests/unit_tests/synthetic/libraries/cublas2rocblas.cu index 3a8034bb..b5fe333a 100644 --- a/tests/unit_tests/synthetic/libraries/cublas2rocblas.cu +++ b/tests/unit_tests/synthetic/libraries/cublas2rocblas.cu @@ -240,6 +240,29 @@ int main() { const float** const fBarray_const = const_cast(fBarray); float** fCarray = 0; float** fTauarray = 0; + float** fyarray = 0; + + // CHECK: rocblas_half** hAarray = 0; + __half** hAarray = 0; + // CHECK: const rocblas_half** const hAarray_const = const_cast(hAarray); + const __half** const hAarray_const = const_cast(hAarray); + // CHECK: rocblas_half** hxarray = 0; + __half** hxarray = 0; + // CHECK: const rocblas_half** const hxarray_const = const_cast(hxarray_const); + const __half** const hxarray_const = const_cast(hxarray_const); + // CHECK: rocblas_half** hyarray = 0; + __half** hyarray = 0; + + // CHECK: rocblas_bfloat16** bf16Aarray = 0; + __nv_bfloat16** bf16Aarray = 0; + // CHECK: const rocblas_bfloat16** const bf16Aarray_const = const_cast(bf16Aarray); + const __nv_bfloat16** const bf16Aarray_const = const_cast(bf16Aarray); + // CHECK: rocblas_bfloat16** bf16xarray = 0; + __nv_bfloat16** bf16xarray = 0; + // CHECK: const rocblas_bfloat16** const bf16xarray_const = const_cast(bf16xarray_const); + const __nv_bfloat16** const bf16xarray_const = const_cast(bf16xarray_const); + // CHECK: rocblas_bfloat16** bf16yarray = 0; + __nv_bfloat16** bf16yarray = 0; double da = 0; double dA = 0; @@ -1770,6 +1793,26 @@ int main() { // ROC: ROCBLAS_EXPORT rocblas_status rocblas_zgemv_strided_batched(rocblas_handle handle, rocblas_operation transA, rocblas_int m, rocblas_int n, const rocblas_double_complex* alpha, const rocblas_double_complex* A, rocblas_int lda, rocblas_stride strideA, const rocblas_double_complex* x, rocblas_int incx, rocblas_stride stridex, const rocblas_double_complex* beta, rocblas_double_complex* y, rocblas_int incy, rocblas_stride stridey, rocblas_int batch_count); // CHECK: blasStatus = rocblas_zgemv_strided_batched(blasHandle, blasOperation, m, n, &dcomplexa, &dcomplexA, lda, strideA, &dcomplexx, incx, stridex, &dcomplexb, &dcomplexy, incy, stridey, batchCount); blasStatus = cublasZgemvStridedBatched(blasHandle, blasOperation, m, n, &dcomplexa, &dcomplexA, lda, strideA, &dcomplexx, incx, stridex, &dcomplexb, &dcomplexy, incy, stridey, batchCount); + + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHSHgemvBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const float* alpha, const __half* const Aarray[], int lda, const __half* const xarray[], int incx, const float* beta, __half* const yarray[], int incy, int batchCount); + // ROC: ROCBLAS_EXPORT rocblas_status rocblas_hshgemv_batched(rocblas_handle handle, rocblas_operation trans, rocblas_int m, rocblas_int n, const float* alpha, const rocblas_half* const A[], rocblas_int lda, const rocblas_half* const x[], rocblas_int incx, const float* beta, rocblas_half* const y[], rocblas_int incy, rocblas_int batch_count); + // CHECK: blasStatus = rocblas_hshgemv_batched(blasHandle, blasOperation, m, n, &fa, hAarray_const, lda, hxarray_const, incx, &fb, hyarray, incy, batchCount); + blasStatus = cublasHSHgemvBatched(blasHandle, blasOperation, m, n, &fa, hAarray_const, lda, hxarray_const, incx, &fb, hyarray, incy, batchCount); + + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHSSgemvBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const float* alpha, const __half* const Aarray[], int lda, const __half* const xarray[], int incx, const float* beta, float* const yarray[], int incy, int batchCount); + // ROC: ROCBLAS_EXPORT rocblas_status rocblas_hssgemv_batched(rocblas_handle handle, rocblas_operation trans, rocblas_int m, rocblas_int n, const float* alpha, const rocblas_half* const A[], rocblas_int lda, const rocblas_half* const x[], rocblas_int incx, const float* beta, float* const y[], rocblas_int incy, rocblas_int batch_count); + // CHECK: blasStatus = rocblas_hssgemv_batched(blasHandle, blasOperation, m, n, &fa, hAarray_const, lda, hxarray_const, incx, &fb, fyarray, incy, batchCount); + blasStatus = cublasHSSgemvBatched(blasHandle, blasOperation, m, n, &fa, hAarray_const, lda, hxarray_const, incx, &fb, fyarray, incy, batchCount); + + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasTSTgemvBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const float* alpha, const __nv_bfloat16* const Aarray[], int lda, const __nv_bfloat16* const xarray[], int incx, const float* beta, __nv_bfloat16* const yarray[], int incy, int batchCount); + // ROC: ROCBLAS_EXPORT rocblas_status rocblas_tstgemv_batched(rocblas_handle handle, rocblas_operation trans, rocblas_int m, rocblas_int n, const float* alpha, const rocblas_bfloat16* const A[], rocblas_int lda, const rocblas_bfloat16* const x[], rocblas_int incx, const float* beta, rocblas_bfloat16* const y[], rocblas_int incy, rocblas_int batch_count); + // CHECK: blasStatus = rocblas_tstgemv_batched(blasHandle, blasOperation, m, n, &fa, bf16Aarray_const, lda, bf16xarray_const, incx, &fb, bf16yarray, incy, batchCount); + blasStatus = cublasTSTgemvBatched(blasHandle, blasOperation, m, n, &fa, bf16Aarray_const, lda, bf16xarray_const, incx, &fb, bf16yarray, incy, batchCount); + + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasTSSgemvBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const float* alpha, const __nv_bfloat16* const Aarray[], int lda, const __nv_bfloat16* const xarray[], int incx, const float* beta, float* const yarray[], int incy, int batchCount); + // ROC: ROCBLAS_EXPORT rocblas_status rocblas_tssgemv_batched(rocblas_handle handle, rocblas_operation trans, rocblas_int m, rocblas_int n, const float* alpha, const rocblas_bfloat16* const A[], rocblas_int lda, const rocblas_bfloat16* const x[], rocblas_int incx, const float* beta, float* const y[], rocblas_int incy, rocblas_int batch_count); + // CHECK: blasStatus = rocblas_tssgemv_batched(blasHandle, blasOperation, m, n, &fa, bf16Aarray_const, lda, bf16xarray_const, incx, &fb, fyarray, incy, batchCount); + blasStatus = cublasTSSgemvBatched(blasHandle, blasOperation, m, n, &fa, bf16Aarray_const, lda, bf16xarray_const, incx, &fb, fyarray, incy, batchCount); #endif return 0;