diff --git a/Cxx11/xgemm-hipblas.cc b/Cxx11/xgemm-hipblas.cc index 7d2efa52d..14121f543 100644 --- a/Cxx11/xgemm-hipblas.cc +++ b/Cxx11/xgemm-hipblas.cc @@ -62,6 +62,7 @@ prk::HIP::info info; +#if 0 template __global__ void init(int order, T * C) { @@ -73,6 +74,20 @@ __global__ void init(int order, T * C) } } +template <> +__global__ void init(int order, hipblasBfloat16 * A, hipblasBfloat16 * B, hipblasBfloat16 * C) +{ + auto i = blockIdx.x * blockDim.x + threadIdx.x; + auto j = blockIdx.y * blockDim.y + threadIdx.y; + + if ((i __global__ void init(int order, T * A, T * B, T * C) { @@ -97,8 +112,8 @@ void prk_gemm(const hipblasHandle_t & h, template <> void prk_gemm(const hipblasHandle_t & h, - const int order, const __half alpha, const __half beta, - const __half * A, const __half * B, __half * C) + const int order, const hipblasHalf alpha, const hipblasHalf beta, + const hipblasHalf * A, const hipblasHalf * B, hipblasHalf * C) { prk::HIP::check( hipblasHgemm(h, HIPBLAS_OP_N, HIPBLAS_OP_N, @@ -200,7 +215,7 @@ void run(const hipblasHandle_t & h, int iterations, int order) auto nflops = 2.0 * prk::pow(forder,3); auto is_fp64 = (typeid(T) == typeid(double)); auto is_fp32 = (typeid(T) == typeid(float)); - auto is_fp16 = (typeid(T) == typeid(__half)); + auto is_fp16 = (typeid(T) == typeid(hipblasHalf)); auto pname = (is_fp64 ? "FP64" : (is_fp32 ? "FP32" : (is_fp16 ? "FP16" : "Unknown FP type"))); @@ -218,7 +233,7 @@ void run(const hipblasHandle_t & h, int iterations, int order) int main(int argc, char * argv[]) { std::cout << "Parallel Research Kernels version " << PRKVERSION << std::endl; - std::cout << "C++11/CUBLAS Dense matrix-matrix multiplication: C += A x B" << std::endl; + std::cout << "C++11/HIPBLAS Dense matrix-matrix multiplication: C += A x B" << std::endl; ////////////////////////////////////////////////////////////////////// /// Read and test input parameters @@ -259,7 +274,8 @@ int main(int argc, char * argv[]) hipblasHandle_t h; prk::HIP::check( hipblasCreate(&h) ); - run<__half>(h, iterations, order); + run(h, iterations, order); + //run(h, iterations, order); run(h, iterations, order); run(h, iterations, order); prk::HIP::check( hipblasDestroy(h) );