-
Notifications
You must be signed in to change notification settings - Fork 214
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature: enabling oneMKL instead of FPKs #2756
feature: enabling oneMKL instead of FPKs #2756
Conversation
This reverts commit 882c4e5.
e523d06
to
9bc24dc
Compare
{ | ||
::oneapi::fpk::gpu::dgemm_sycl(&_queue, transa, transb, m, n, k, alpha, &a, lda, &b, ldb, beta, &c, ldc, offset_a, offset_b, offset_c); | ||
//mkl::blas::gpu::dgemm_sycl(&_queue, transa, transb, m, n, k, alpha, &a, lda, &b, ldb, beta, &c, ldc, offset_a, offset_b, offset_c); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please consider using a special deprivation error for such cases
} | ||
|
||
static void xxgemm(const char * transa, const char * transb, const DAAL_INT * p, const DAAL_INT * ny, const DAAL_INT * n, const double * alpha, | ||
const double * a, const DAAL_INT * lda, const double * y, const DAAL_INT * ldy, const double * beta, double * aty, | ||
const DAAL_INT * ldaty) | ||
{ | ||
__DAAL_MKLFN_CALL(blas_, xdgemm, (transa, transb, p, ny, n, alpha, a, lda, y, ldy, beta, aty, ldaty)); | ||
__DAAL_MKLFN_CALL( | ||
blas_, dgemm, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since you are already changing these calls please consider using static_cast
instead of C-like casts
int old_threads = fpk_serv_set_num_threads_local(1); | ||
__DAAL_MKLFN_CALL(blas_, ssymm, (side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc)); | ||
fpk_serv_set_num_threads_local(old_threads); | ||
int old_threads = mkl_serv_set_num_threads_local(1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about wrapping this recurring factor into RAII entity? You can obtain current number of threads in the constructor and release it in the destructor respectively.
fpk_serv_set_num_threads_local(old_threads); | ||
int old_threads = mkl_serv_set_num_threads_local(1); | ||
__DAAL_MKLFN_CALL(lapack_, dorgqr, | ||
((MKL_INT *)(&m), (MKL_INT *)(&n), (MKL_INT *)(&k), a, (MKL_INT *)(&lda), tau, work, (MKL_INT *)(&lwork), (MKL_INT *)info)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should some of the arguments be const pointers to a constant value?
@@ -106,7 +106,7 @@ int uniformRNG(const size_t cn, size_t * r, void * stream, const size_t a, const | |||
} | |||
else | |||
{ | |||
unsigned __int64 * cr = (unsigned __int64 *)r; | |||
unsigned long long * cr = (unsigned long long *)r; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
std::int64_t
?
} | ||
|
||
static void xcsrmv(const char * transa, const DAAL_INT * m, const DAAL_INT * k, const double * alpha, const char * matdescra, const double * val, | ||
const DAAL_INT * indx, const DAAL_INT * pntrb, const DAAL_INT * pntre, const double * x, const double * beta, double * y) | ||
{ | ||
__DAAL_MKLFN_CALL(spblas_, mkl_dcsrmv, (transa, m, k, alpha, matdescra, val, indx, pntrb, pntre, x, beta, y)); | ||
sparse_matrix_t csrA = NULL; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should it be the RAII entity as well?
@@ -130,17 +136,25 @@ extern "C" | |||
|
|||
static void _daal_mkl_threader_for(DAAL_INT n, DAAL_INT threads_request, void * a, func_type func) | |||
{ | |||
fpk_vsl_serv_threader_for(n, threads_request, a, func); | |||
// // fpk_vsl_serv_threader_for(n, threads_request, a, func); | |||
for (DAAL_INT i = 0; i < n; i++) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for (DAAL_INT i = 0; i < n; i++) | |
for (DAAL_INT i = 0; i < n; ++i) |
auto event_ = queue_.submit([&](sycl::handler& cgh) { | ||
cgh.depends_on({ last_event }); | ||
cgh.parallel_for(nd_range, [=](sycl::nd_item<1> id) { | ||
auto idx = id.get_global_id(0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should it be const
?
rn_gen.uniform_cpu(ctx.selected_ftr_count_, | ||
random_bins_host_ptr + node * ctx.selected_ftr_count_, | ||
rng_engine_list[tree_map_ptr[node]], | ||
0.0f, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the kernel operates with double
s?
// } | ||
|
||
// auto call_sym_eigvals_descending(const la::matrix<Float>& symmetric_matrix, | ||
// std::int64_t eigval_count) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wow 🫠😅
Description
Please include a summary of the change. For large or complex changes please include enough information to introduce your change and explain motivation for it.
Changes proposed in this pull request: