Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tests][modernize]Make the tests compile with hipSYCL #4

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tests/unit_tests/blas/batch/axpy_batch_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ namespace {
template <typename fp>
int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) {
// Catch asynchronous exceptions.
auto exception_handler = [](exception_list exceptions) {
cl::sycl::async_handler exception_handler = [](exception_list exceptions) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it even necessary to not use auto? I think it should also work with auto.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mhmm, it doesn't seem to be necessary anymore... I recall that a few weeks ago there was an issue with hipSYCL regarding this.

for (std::exception_ptr const &e : exceptions) {
try {
std::rethrow_exception(e);
}
catch (exception const &e) {
std::cout << "Caught asynchronous SYCL exception during AXPY_BATCH:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -186,7 +186,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) {
catch (exception const &e) {
std::cout << "Caught synchronous SYCL exception during AXPY_BATCH:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented &e) {
Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/blas/batch/gemm_batch_stride.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,15 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) {
// Call DPC++ GEMM_BATCH_STRIDE.

// Catch asynchronous exceptions.
auto exception_handler = [](exception_list exceptions) {
cl::sycl::async_handler exception_handler = [](exception_list exceptions) {
for (std::exception_ptr const &e : exceptions) {
try {
std::rethrow_exception(e);
}
catch (exception const &e) {
std::cout << "Caught asynchronous SYCL exception during GEMM_BATCH_STRIDE:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -181,7 +181,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) {
catch (exception const &e) {
std::cout << "Caught synchronous SYCL exception during GEMM_BATCH_STRIDE:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented &e) {
Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ namespace {
template <typename fp>
int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) {
// Catch asynchronous exceptions.
auto exception_handler = [](exception_list exceptions) {
cl::sycl::async_handler exception_handler = [](exception_list exceptions) {
for (std::exception_ptr const &e : exceptions) {
try {
std::rethrow_exception(e);
}
catch (exception const &e) {
std::cout << "Caught asynchronous SYCL exception during GEMM_BATCH_STRIDE:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -208,7 +208,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) {
catch (exception const &e) {
std::cout << "Caught synchronous SYCL exception during GEMM_BATCH_STRIDE:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented &e) {
Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/blas/batch/gemm_batch_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ namespace {
template <typename fp>
int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) {
// Catch asynchronous exceptions.
auto exception_handler = [](exception_list exceptions) {
cl::sycl::async_handler exception_handler = [](exception_list exceptions) {
for (std::exception_ptr const &e : exceptions) {
try {
std::rethrow_exception(e);
}
catch (exception const &e) {
std::cout << "Caught asynchronous SYCL exception during GEMM_BATCH:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -263,7 +263,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) {
catch (exception const &e) {
std::cout << "Caught synchronous SYCL exception during GEMM_BATCH:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented &e) {
Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/blas/batch/trsm_batch_stride.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,15 @@ int test(device *dev, oneapi::mkl::layout layout) {
// Call DPC++ TRSM_BATCH_STRIDE.

// Catch asynchronous exceptions.
auto exception_handler = [](exception_list exceptions) {
cl::sycl::async_handler exception_handler = [](exception_list exceptions) {
for (std::exception_ptr const &e : exceptions) {
try {
std::rethrow_exception(e);
}
catch (exception const &e) {
std::cout << "Caught asynchronous SYCL exception during TRSM_BATCH_STRIDE:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -173,7 +173,7 @@ int test(device *dev, oneapi::mkl::layout layout) {
catch (exception const &e) {
std::cout << "Caught synchronous SYCL exception during TRSM_BATCH_STRIDE:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented &e) {
Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/blas/extensions/gemm_bias.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa,
// Call DPC++ GEMM_BIAS.

// Catch asynchronous exceptions.
auto exception_handler = [](exception_list exceptions) {
cl::sycl::async_handler exception_handler = [](exception_list exceptions) {
for (std::exception_ptr const& e : exceptions) {
try {
std::rethrow_exception(e);
}
catch (exception const& e) {
std::cout << "Caught asynchronous SYCL exception during GEMM_BIAS:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -142,7 +142,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa,
catch (exception const& e) {
std::cout << "Caught synchronous SYCL exception during GEMM_BIAS:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented& e) {
Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/blas/extensions/gemmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower,
// Call DPC++ GEMMT.

// Catch asynchronous exceptions
auto exception_handler = [](exception_list exceptions) {
cl::sycl::async_handler exception_handler = [](exception_list exceptions) {
for (std::exception_ptr const& e : exceptions) {
try {
std::rethrow_exception(e);
}
catch (exception const& e) {
std::cout << "Caught asynchronous SYCL exception during GEMMT:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -121,7 +121,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower,
catch (exception const& e) {
std::cout << "Caught synchronous SYCL exception during GEMMT:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented& e) {
Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/blas/extensions/gemmt_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower,
oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, int n, int k, int lda,
int ldb, int ldc, fp alpha, fp beta) {
// Catch asynchronous exceptions.
auto exception_handler = [](exception_list exceptions) {
cl::sycl::async_handler exception_handler = [](exception_list exceptions) {
for (std::exception_ptr const& e : exceptions) {
try {
std::rethrow_exception(e);
}
catch (exception const& e) {
std::cout << "Caught asynchronous SYCL exception during GEMMT:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -123,7 +123,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower,
catch (exception const& e) {
std::cout << "Caught synchronous SYCL exception during GEMMT:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented& e) {
Expand Down
8 changes: 4 additions & 4 deletions tests/unit_tests/blas/include/reference_blas_templates.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ template <typename fp>
static void gemm(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transa, CBLAS_TRANSPOSE transb, const int *m,
const int *n, const int *k, const fp *alpha, const fp *a, const int *lda,
const fp *b, const int *ldb, const fp *beta, fp *c, const int *ldc);

#ifdef NOT_HIPSYCL

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you use the ENABLE_HALF_ROUTINES here and get rid of the explicit NOT_HIPSYCL macro?
or instead just use a macro that is defined by hipSYCL anyways?

Copy link
Author

@sbalint98 sbalint98 May 21, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ohh.. this shouldn't be here at all... this should be in #6 ... Thanks for noticing :) But I think I'll let this be here, instead of moving it over to the other PR

template <>
void gemm(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transa, CBLAS_TRANSPOSE transb, const int *m,
const int *n, const int *k, const half *alpha, const half *a, const int *lda,
Expand Down Expand Up @@ -255,7 +255,7 @@ void gemm(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transa, CBLAS_TRANSPOSE transb, c
oneapi::mkl::aligned_free(bf);
oneapi::mkl::aligned_free(cf);
}

#endif
template <>
void gemm(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transa, CBLAS_TRANSPOSE transb, const int *m,
const int *n, const int *k, const float *alpha, const float *a, const int *lda,
Expand Down Expand Up @@ -291,7 +291,7 @@ template <typename fpa, typename fpc>
static void gemm(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transa, CBLAS_TRANSPOSE transb, const int *m,
const int *n, const int *k, const fpc *alpha, const fpa *a, const int *lda,
const fpa *b, const int *ldb, const fpc *beta, fpc *c, const int *ldc);

#ifdef NOT_HIPSYCL
template <>
void gemm(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transa, CBLAS_TRANSPOSE transb, const int *m,
const int *n, const int *k, const float *alpha, const half *a, const int *lda,
Expand All @@ -314,7 +314,7 @@ void gemm(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transa, CBLAS_TRANSPOSE transb, c
oneapi::mkl::aligned_free(af);
oneapi::mkl::aligned_free(bf);
}

#endif
template <typename fp>
static void symm(CBLAS_LAYOUT layout, CBLAS_SIDE left_right, CBLAS_UPLO uplo, const int *m,
const int *n, const fp *alpha, const fp *a, const int *lda, const fp *b,
Expand Down
6 changes: 4 additions & 2 deletions tests/unit_tests/blas/include/test_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@
#include <CL/sycl.hpp>

namespace std {
#ifdef NOT_HIPSYCL
static cl::sycl::half abs(cl::sycl::half v) {
if (v < cl::sycl::half(0))
return -v;
else
return v;
}
#endif
} // namespace std

// Complex helpers.
Expand Down Expand Up @@ -140,12 +142,12 @@ template <>
uint8_t rand_scalar() {
return std::rand() % 128;
}

#ifdef NOT_HIPSYCL
template <>
half rand_scalar() {
return half(std::rand() % 32000) / half(32000) - half(0.5);
}

#endif
template <typename fp>
static fp rand_scalar(int mag) {
fp tmp = fp(mag) + fp(std::rand()) / fp(RAND_MAX) - fp(0.5);
Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/blas/level1/asum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ int test(device* dev, oneapi::mkl::layout layout, int64_t N, int64_t incx) {
// Call DPC++ ASUM.

// Catch asynchronous exceptions.
auto exception_handler = [](exception_list exceptions) {
cl::sycl::async_handler exception_handler = [](exception_list exceptions) {
for (std::exception_ptr const& e : exceptions) {
try {
std::rethrow_exception(e);
}
catch (exception const& e) {
std::cout << "Caught asynchronous SYCL exception during ASUM:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -103,7 +103,7 @@ int test(device* dev, oneapi::mkl::layout layout, int64_t N, int64_t incx) {
catch (exception const& e) {
std::cout << "Caught synchronous SYCL exception during ASUM:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented& e) {
Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/blas/level1/asum_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ namespace {
template <typename fp, typename fp_res>
int test(device* dev, oneapi::mkl::layout layout, int64_t N, int64_t incx) {
// Catch asynchronous exceptions.
auto exception_handler = [](exception_list exceptions) {
cl::sycl::async_handler exception_handler = [](exception_list exceptions) {
for (std::exception_ptr const& e : exceptions) {
try {
std::rethrow_exception(e);
}
catch (exception const& e) {
std::cout << "Caught asynchronous SYCL exception during ASUM:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -111,7 +111,7 @@ int test(device* dev, oneapi::mkl::layout layout, int64_t N, int64_t incx) {
catch (exception const& e) {
std::cout << "Caught synchronous SYCL exception during ASUM:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented& e) {
Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/blas/level1/axpy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,15 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp
// Call DPC++ AXPY.

// Catch asynchronous exceptions.
auto exception_handler = [](exception_list exceptions) {
cl::sycl::async_handler exception_handler = [](exception_list exceptions) {
for (std::exception_ptr const &e : exceptions) {
try {
std::rethrow_exception(e);
}
catch (exception const &e) {
std::cout << "Caught asynchronous SYCL exception during AXPY:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -108,7 +108,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp
catch (exception const &e) {
std::cout << "Caught synchronous SYCL exception during AXPY:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented &e) {
Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/blas/level1/axpy_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ namespace {
template <typename fp>
int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp alpha) {
// Catch asynchronous exceptions.
auto exception_handler = [](exception_list exceptions) {
cl::sycl::async_handler exception_handler = [](exception_list exceptions) {
for (std::exception_ptr const &e : exceptions) {
try {
std::rethrow_exception(e);
}
catch (exception const &e) {
std::cout << "Caught asynchronous SYCL exception during AXPY:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -112,7 +112,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp
catch (exception const &e) {
std::cout << "Caught synchronous SYCL exception during AXPY:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented &e) {
Expand Down
Loading