Skip to content

Commit

Permalink
add test for LAPACKE_sgesdd
Browse files Browse the repository at this point in the history
  • Loading branch information
SMovaghati committed Nov 9, 2023
1 parent 3880a8d commit 0d3a4db
Showing 1 changed file with 58 additions and 5 deletions.
63 changes: 58 additions & 5 deletions apps/test_open_blas.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifdef USE_OPENBLAS
#include <cblas.h>
#include <openblas_config.h>
// #include<lapacke.h> //TODO: commented out for now, until we fix the error in header.
#else
#include <mkl.h>
#endif
Expand All @@ -27,9 +28,16 @@ using Flex_CBLAS_TRANSPOSE = CBLAS_TRANSPOSE;
using Flex_CBLAS_TRANSPOSE = CBLAS_TRANSPOSE;
#endif

#ifdef USE_OPENBLAS
using Flex_LAPACK_INT = int;
#else
using Flex_LAPACK_INT = MKL_INT;
#endif

int test_cblas_snrm2();
int test_cblas_sdot();
int test_cblas_sgemm();
int test_LAPACKE_sgesdd();

// A temporary test just to play with OpenBLAS
int main(int argc, char **argv)
Expand All @@ -43,6 +51,7 @@ int main(int argc, char **argv)
auto errorCode = test_cblas_snrm2();
errorCode += test_cblas_sdot();
errorCode += test_cblas_sgemm();
errorCode += test_LAPACKE_sgesdd();

if (errorCode == 0)
{
Expand Down Expand Up @@ -120,9 +129,8 @@ int test_cblas_sgemm()
std::vector<float> C(m * n);

cblas_sgemm(Flex_CBLAS_ORDER::CblasRowMajor, Flex_CBLAS_TRANSPOSE::CblasNoTrans, Flex_CBLAS_TRANSPOSE::CblasNoTrans,
(Flex_INT)m, (Flex_INT)n,
(Flex_INT)k, alpha, A.data(),
(Flex_INT)lda, B.data(), (Flex_INT)ldb, beta, C.data(), (Flex_INT)ldc);
(Flex_INT)m, (Flex_INT)n, (Flex_INT)k, alpha, A.data(), (Flex_INT)lda, B.data(), (Flex_INT)ldb, beta,
C.data(), (Flex_INT)ldc);

#ifdef USE_OPENBLAS
// Expected result from intelMKL: all the values should be 6.0
Expand All @@ -135,7 +143,7 @@ int test_cblas_sgemm()
return 1;
}
}

#else
printf("test_cblas_sgemm result:\n");
for (auto val : C)
Expand All @@ -147,4 +155,49 @@ int test_cblas_sgemm()

printf("Completed\n-------------------------\n");
return 0;
}
}

int test_LAPACKE_sgesdd()
{
printf("Testing test_LAPACKE_sgesdd... \n");

Flex_LAPACK_INT size = 3;
Flex_LAPACK_INT m = size, k = size, n = size;
Flex_LAPACK_INT lda = size, ldu = size, ldvt = size;

std::vector<float> A(m * k, 1.0);
std::vector<float> S(k * n, 2.0);
std::vector<float> U(m * n);
std::vector<float> VT(m * n);

#ifndef USE_OPENBLAS // TODO: this won't work for OpenBLas until we fix the lapack header
uint32_t errcode = (uint32_t)LAPACKE_sgesdd(LAPACK_ROW_MAJOR, 'A', (Flex_LAPACK_INT)m, (Flex_LAPACK_INT)n, A.data(),
(Flex_LAPACK_INT)lda, S.data(), U.data(), (Flex_LAPACK_INT)ldu,
VT.data(), (Flex_LAPACK_INT)ldvt);

#ifdef USE_OPENBLAS
// Expected result from intelMKL: VT: -0.577350, -0.577350, -0.577350, 0.816497, -0.408248, -0.408248, 0.000000,
// -0.707107, 0.707107,
std::vector<float> expectedVT{-0.577350, -0.577350, -0.577350, 0.816497, -0.408248,
-0.408248, 0.000000, -0.707107, 0.707107};
for (size_t i = 0; i < VT.size(); i++)
{
if (std::fabs(VT[i] - expectedVT[i]) > 1.0e-4f)
{
printf("OPEN BLAS value (%f) is not matching with Intel MKL value (%f)... \n\n", VT[i], expectedVT[i]);
printf("Validation FAILED :( \n-------------------------\n");
return 1;
}
}
#else
printf("test_cblas_sgemm result:\n");
for (auto val : VT)
{
printf("%f, ", val);
}
printf("\n\n");
#endif
#endif
printf("Completed\n-------------------------\n");
return 0;
}

0 comments on commit 0d3a4db

Please sign in to comment.