Skip to content

Commit

Permalink
bfloat16 is awful to use
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffhammond committed Aug 16, 2023
1 parent cf6412e commit 929f43a
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions Cxx11/xgemm-hipblas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@

prk::HIP::info info;

#if 0
template <typename T>
__global__ void init(int order, T * C)
{
Expand All @@ -73,6 +74,20 @@ __global__ void init(int order, T * C)
}
}

// Explicit specialization of the three-array init kernel for hipblasBfloat16.
//
// Each thread forms one global index from the x components of the launch
// (blockIdx/blockDim/threadIdx) and one from the y components. Threads whose
// indices both lie below `order` write a single element of A, B and C at
// offset (x-index * order + y-index):
//   A and B receive the x-index converted through
//   hipblasBfloat16::float_to_bfloat16, and C receives a converted zero.
// All other threads return without touching memory.
//
// NOTE(review): the accessibility and return type of
// hipblasBfloat16::float_to_bfloat16 depend on the hipBLAS header in use;
// confirm it is callable from device code here before enabling this kernel.
// NOTE(review): an explicit specialization must follow a matching primary
// template declaration; verify the ordering in this file.
template <>
__global__ void init(int order, hipblasBfloat16 * A, hipblasBfloat16 * B, hipblasBfloat16 * C)
{
    const auto gx = blockIdx.x * blockDim.x + threadIdx.x;
    const auto gy = blockIdx.y * blockDim.y + threadIdx.y;

    // Guard clause: the launch may cover more threads than there are elements.
    if ((gx >= order) || (gy >= order)) {
        return;
    }

    // Shared element offset for the three arrays.
    const auto offset = gx * order + gy;

    A[offset] = hipblasBfloat16::float_to_bfloat16(gx);
    B[offset] = hipblasBfloat16::float_to_bfloat16(gx);
    C[offset] = hipblasBfloat16::float_to_bfloat16(0);
}
#endif

template <typename T>
__global__ void init(int order, T * A, T * B, T * C)
{
Expand All @@ -97,8 +112,8 @@ void prk_gemm(const hipblasHandle_t & h,

template <>
void prk_gemm(const hipblasHandle_t & h,
const int order, const __half alpha, const __half beta,
const __half * A, const __half * B, __half * C)
const int order, const hipblasHalf alpha, const hipblasHalf beta,
const hipblasHalf * A, const hipblasHalf * B, hipblasHalf * C)
{
prk::HIP::check( hipblasHgemm(h,
HIPBLAS_OP_N, HIPBLAS_OP_N,
Expand Down Expand Up @@ -200,7 +215,7 @@ void run(const hipblasHandle_t & h, int iterations, int order)
auto nflops = 2.0 * prk::pow(forder,3);
auto is_fp64 = (typeid(T) == typeid(double));
auto is_fp32 = (typeid(T) == typeid(float));
auto is_fp16 = (typeid(T) == typeid(__half));
auto is_fp16 = (typeid(T) == typeid(hipblasHalf));
auto pname = (is_fp64 ? "FP64" :
(is_fp32 ? "FP32" :
(is_fp16 ? "FP16" : "Unknown FP type")));
Expand All @@ -218,7 +233,7 @@ void run(const hipblasHandle_t & h, int iterations, int order)
int main(int argc, char * argv[])
{
std::cout << "Parallel Research Kernels version " << PRKVERSION << std::endl;
std::cout << "C++11/CUBLAS Dense matrix-matrix multiplication: C += A x B" << std::endl;
std::cout << "C++11/HIPBLAS Dense matrix-matrix multiplication: C += A x B" << std::endl;

//////////////////////////////////////////////////////////////////////
/// Read and test input parameters
Expand Down Expand Up @@ -259,7 +274,8 @@ int main(int argc, char * argv[])

hipblasHandle_t h;
prk::HIP::check( hipblasCreate(&h) );
run<__half>(h, iterations, order);
run<hipblasHalf>(h, iterations, order);
//run<hipblasBfloat16>(h, iterations, order);
run<float>(h, iterations, order);
run<double>(h, iterations, order);
prk::HIP::check( hipblasDestroy(h) );
Expand Down

0 comments on commit 929f43a

Please sign in to comment.