From cdc34864c06c40f44894e773abd98999a27095e3 Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Wed, 20 Dec 2023 13:22:54 +0000 Subject: [PATCH 01/21] Fix joint_matrix_mad api call * Remove get_wi_data() and replace with joint_matrix_copy and joint_matrix_apply --- .../blas3/gemm_local_joint_matrix.hpp | 219 +++++++++--------- 1 file changed, 104 insertions(+), 115 deletions(-) diff --git a/src/operations/blas3/gemm_local_joint_matrix.hpp b/src/operations/blas3/gemm_local_joint_matrix.hpp index 664eed416..68b7e3a06 100644 --- a/src/operations/blas3/gemm_local_joint_matrix.hpp +++ b/src/operations/blas3/gemm_local_joint_matrix.hpp @@ -158,18 +158,14 @@ class Gemm PORTBLAS_INLINE void eval(local_memory_t scratch_acc, - const cl::sycl::nd_item<1> &id) noexcept { + const cl::sycl::nd_item<1> &id) noexcept { index_t m = a_.get_size_row(); index_t n = b_.get_size_col(); index_t k = a_.get_size_col(); @@ -270,10 +266,12 @@ class Gemm(id.get_group(0)); // The batch index that each workgroup should start working with const index_t x_groups = (get_wg_x_cluster() - 1) / jm_row_frags + 1; const index_t y_groups = (get_wg_y_cluster() - 1) / jm_col_frags + 1; - const index_t wg_batch_id = id.get_group(0) / (x_groups * y_groups); + const index_t wg_batch_id = wg_id / (x_groups * y_groups); // This will disable all workgroups that dont have any batch to work on if (wg_batch_id >= batch_size_) { return; @@ -283,20 +281,23 @@ class Gemm - (a_.get_pointer()) + (wg_batch_id * stridea_); - auto ptr_B = cl::sycl::multi_ptr - (b_.get_pointer()) + (wg_batch_id * strideb_); - auto ptr_C = cl::sycl::multi_ptr - (c_.get_pointer()) + (wg_batch_id * stridec_); + auto ptr_A = cl::sycl::multi_ptr( + a_.get_pointer()) + + (wg_batch_id * stridea_); + auto ptr_B = cl::sycl::multi_ptr( + b_.get_pointer()) + + (wg_batch_id * strideb_); + auto ptr_C = cl::sycl::multi_ptr( + c_.get_pointer()) + + (wg_batch_id * stridec_); auto sg = id.get_sub_group(); const index_t sg_id = sg.get_group_linear_id(); 
@@ -323,8 +324,7 @@ class Gemm( - id, ofs, s1, s2, s3, s4); + sync_smem(id, ofs, s1, s2, s3, s4); k -= cl_elems; } @@ -473,12 +462,8 @@ class Gemm(id, s2, s4, reg_res); - sync_smem( - id, ofs, s1, s2, s3, s4); + sync_smem(id, ofs, s1, s2, s3, s4); } // store the output @@ -517,11 +502,11 @@ class Gemm PORTBLAS_INLINE void store_output_block(cl::sycl::nd_item<1> id, index_t mc, - index_t nc, OutputPointerType C, - ScratchPointerType scratch, - index_t ldc, - CType (®_res)[frags_per_sg], - const bool out_of_range) noexcept { + index_t nc, OutputPointerType C, + ScratchPointerType scratch, + index_t ldc, + CType (®_res)[frags_per_sg], + const bool out_of_range) noexcept { if (out_of_range) { return; } @@ -541,51 +526,49 @@ class Gemm= tile_type::joint_matrix_M && - sg_nc >= tile_type::joint_matrix_N) - ? true - : false - : true; - - if (jm_store_feasible) { - const index_t loop_limit = - (tile_type::joint_matrix_M * tile_type::joint_matrix_N) / sg_size; - - if constexpr (is_beta_zero) { - for (index_t frag = 0; frag < frags_per_sg; frag++) { - for (index_t i = 0; i < loop_limit; i++) { - element_t data_left = - static_cast(get_wi_data(sg, reg_res[frag])[i]); - get_wi_data(sg, float_out)[i] = alpha_ * data_left; - } - - joint_matrix_store(sg, float_out, C, ldc, layout::col_major); - - C += (tile_type::joint_matrix_N * ldc); - } - - } else { - for (index_t frag = 0; frag < frags_per_sg; frag++) { - joint_matrix_load(sg, float_out, C, ldc, layout::col_major); - - for (index_t i = 0; i < loop_limit; i++) { - element_t data_left = - static_cast(get_wi_data(sg, reg_res[frag])[i]); - element_t data_right = get_wi_data(sg, float_out)[i]; - get_wi_data(sg, float_out)[i] = - beta_ * data_right + alpha_ * data_left; - } - - joint_matrix_store(sg, float_out, C, ldc, layout::col_major); - - C += (tile_type::joint_matrix_N * ldc); - } - } - return; - } else if (sg_mc <= 0 || sg_nc <= 0) { - return; - } + // const bool jm_store_feasible = (mc < block_rows || nc < block_cols) + // ? 
(sg_mc >= tile_type::joint_matrix_M + // && + // sg_nc >= tile_type::joint_matrix_N) + // ? true + // : false + // : true; + + // if (jm_store_feasible) { + + // if constexpr (is_beta_zero) { + // for (index_t frag = 0; frag < frags_per_sg; frag++) { + // joint_matrix_copy(sg, reg_res[frag], float_out); + // joint_matrix_apply(sg, float_out, [=](element_t &x){ + // x *= alpha_; + // }); + // joint_matrix_store(sg, float_out, C, ldc, layout::col_major); + // C += (tile_type::joint_matrix_N * ldc); + // } + + // } else { + // for (index_t frag = 0; frag < frags_per_sg; frag++) { + // joint_matrix_load(sg, float_out, C, ldc, layout::col_major); + // joint_matrix_apply(sg, float_out, [=](element_t &x){ + // x *= beta_; + // }); + // // for (index_t i = 0; i < loop_limit; i++) { + // // element_t data_left = + // // static_cast(get_wi_data(sg, reg_res[frag])[i]); + // // element_t data_right = get_wi_data(sg, float_out)[i]; + // // get_wi_data(sg, float_out)[i] = + // // beta_ * data_right + alpha_ * data_left; + // // } + + // joint_matrix_store(sg, float_out, C, ldc, layout::col_major); + + // C += (tile_type::joint_matrix_N * ldc); + // } + // } + // return; + // } else if (sg_mc <= 0 || sg_nc <= 0) { + // return; + // } id.barrier(cl::sycl::access::fence_space::local_space); @@ -598,13 +581,8 @@ class Gemm(get_wi_data(sg, reg_res[frag])[i]); - } + joint_matrix_copy(sg, reg_res[frag], float_out); + joint_matrix_apply(sg, float_out, [=](element_t &x) { x *= alpha_; }); joint_matrix_store(sg, float_out, new_scratch, sg_store_ld, layout::col_major); @@ -614,17 +592,28 @@ class Gemm= tile_type::joint_matrix_M && + sg_nc >= tile_type::joint_matrix_N) { + for (index_t i = 0; i < loop_limit; i++) { + if constexpr (is_beta_zero) { + *(new_C + i * ldc) = *(new_scratch + i * sg_store_ld); + } else { + element_t data_left = *(new_C + i * ldc); + element_t data_right = *(new_scratch + i * sg_store_ld); + *(new_C + i * ldc) = data_right + beta_ * data_left; + } + } + } else if 
(sg_mc < tile_type::joint_matrix_M && + sg_nc < tile_type::joint_matrix_N) { if (sg_item_id < sg_mc) { for (index_t i = 0; i < loop_limit; i++) { if constexpr (is_beta_zero) { element_t data_right = *(new_scratch + i * sg_store_ld); - *(new_C + i * ldc) = alpha_ * data_right; + *(new_C + i * ldc) = data_right; } else { element_t data_left = *(new_C + i * ldc); element_t data_right = *(new_scratch + i * sg_store_ld); - *(new_C + i * ldc) = alpha_ * data_right + beta_ * data_left; + *(new_C + i * ldc) = data_right + beta_ * data_left; } } } @@ -633,11 +622,11 @@ class Gemm static PORTBLAS_INLINE typename std::enable_if::type sync_smem( const cl::sycl::nd_item<1> &id, index_t &ofs_sign, P &s, - Ps &... ss) noexcept { + Ps &...ss) noexcept { s += ofs_sign * o; sync_smem(id, ofs_sign, ss...); } From 548321cef4d9d41f68c085f62317988519add2dd Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Tue, 26 Dec 2023 13:43:51 +0000 Subject: [PATCH 02/21] Fixed output for matrices with no corner cases --- .../blas3/gemm_local_joint_matrix.hpp | 147 ++++-------------- 1 file changed, 31 insertions(+), 116 deletions(-) diff --git a/src/operations/blas3/gemm_local_joint_matrix.hpp b/src/operations/blas3/gemm_local_joint_matrix.hpp index 68b7e3a06..c0de6fcf0 100644 --- a/src/operations/blas3/gemm_local_joint_matrix.hpp +++ b/src/operations/blas3/gemm_local_joint_matrix.hpp @@ -162,6 +162,10 @@ class Gemm= m || wg_col >= n); const bool internal = m - wg_row >= block_rows && n - wg_col >= block_cols; - ptr_C += - (wg_row + (sg_id % jm_row_frags) * tile_type::joint_matrix_M) + - (wg_col + (sg_id / jm_row_frags) * tile_type::joint_matrix_N) * ldc; + ptr_C += (wg_row + wg_col * ldc); const index_t mc = m - wg_row; const index_t nc = n - wg_col; @@ -318,6 +320,8 @@ class Gemm(id.get_local_linear_id()); + const index_t local_range = static_cast(id.get_local_range(0)); const index_t sg_id = static_cast(sg.get_group_linear_id()); - const index_t sg_range = 
static_cast(sg.get_group_linear_range()); - const index_t sg_item_id = static_cast(sg.get_local_linear_id()); - - const index_t sg_mc = - mc - (sg_id % jm_row_frags) * tile_type::joint_matrix_M; - const index_t sg_nc = - nc - (sg_id / jm_row_frags) * tile_type::joint_matrix_N; - // const bool jm_store_feasible = (mc < block_rows || nc < block_cols) - // ? (sg_mc >= tile_type::joint_matrix_M - // && - // sg_nc >= tile_type::joint_matrix_N) - // ? true - // : false - // : true; - - // if (jm_store_feasible) { - - // if constexpr (is_beta_zero) { - // for (index_t frag = 0; frag < frags_per_sg; frag++) { - // joint_matrix_copy(sg, reg_res[frag], float_out); - // joint_matrix_apply(sg, float_out, [=](element_t &x){ - // x *= alpha_; - // }); - // joint_matrix_store(sg, float_out, C, ldc, layout::col_major); - // C += (tile_type::joint_matrix_N * ldc); - // } - - // } else { - // for (index_t frag = 0; frag < frags_per_sg; frag++) { - // joint_matrix_load(sg, float_out, C, ldc, layout::col_major); - // joint_matrix_apply(sg, float_out, [=](element_t &x){ - // x *= beta_; - // }); - // // for (index_t i = 0; i < loop_limit; i++) { - // // element_t data_left = - // // static_cast(get_wi_data(sg, reg_res[frag])[i]); - // // element_t data_right = get_wi_data(sg, float_out)[i]; - // // get_wi_data(sg, float_out)[i] = - // // beta_ * data_right + alpha_ * data_left; - // // } - - // joint_matrix_store(sg, float_out, C, ldc, layout::col_major); - - // C += (tile_type::joint_matrix_N * ldc); - // } - // } - // return; - // } else if (sg_mc <= 0 || sg_nc <= 0) { - // return; - // } - - id.barrier(cl::sycl::access::fence_space::local_space); - scratch += sg_id * tile_type::joint_matrix_M; - const index_t sg_store_ld = sg_range * tile_type::joint_matrix_M; + const index_t output_local_store_offset = + (sg_id % jm_row_frags) * tile_type::joint_matrix_M + + (sg_id / jm_row_frags) * ldsc * tile_type::joint_matrix_N; + const index_t output_local_load_offset = + item_id % block_rows 
+ (item_id / block_rows) * ldsc; + const index_t rows_per_iter = local_range / block_rows; const index_t loop_limit = - sg_nc >= tile_type::joint_matrix_N ? tile_type::joint_matrix_N : sg_nc; + (frags_per_sg == 1 ? block_cols : tile_type::joint_matrix_N) / + rows_per_iter; - for (index_t frag = 0; frag < frags_per_sg; frag++, C += ldc * loop_limit) { + const index_t output_global_outer_offset = ldc * tile_type::joint_matrix_N; + const index_t output_global_inner_offset = ldc * rows_per_iter; + const index_t output_local_inner_offset = ldsc * rows_per_iter; + + for (index_t frag = 0; frag < frags_per_sg; + frag++, C += output_global_outer_offset) { auto new_C = C; - auto new_scratch = scratch; + auto new_scratch = scratch + output_local_load_offset; joint_matrix_copy(sg, reg_res[frag], float_out); joint_matrix_apply(sg, float_out, [=](element_t &x) { x *= alpha_; }); - joint_matrix_store(sg, float_out, new_scratch, sg_store_ld, - layout::col_major); + id.barrier(cl::sycl::access::fence_space::local_space); + + joint_matrix_store(sg, float_out, scratch + output_local_store_offset, + ldsc, layout::col_major); id.barrier(cl::sycl::access::fence_space::local_space); - new_C += sg_item_id; - new_scratch += sg_item_id; - - if (sg_mc >= tile_type::joint_matrix_M && - sg_nc >= tile_type::joint_matrix_N) { - for (index_t i = 0; i < loop_limit; i++) { - if constexpr (is_beta_zero) { - *(new_C + i * ldc) = *(new_scratch + i * sg_store_ld); - } else { - element_t data_left = *(new_C + i * ldc); - element_t data_right = *(new_scratch + i * sg_store_ld); - *(new_C + i * ldc) = data_right + beta_ * data_left; - } - } - } else if (sg_mc < tile_type::joint_matrix_M && - sg_nc < tile_type::joint_matrix_N) { - if (sg_item_id < sg_mc) { - for (index_t i = 0; i < loop_limit; i++) { - if constexpr (is_beta_zero) { - element_t data_right = *(new_scratch + i * sg_store_ld); - *(new_C + i * ldc) = data_right; - } else { - element_t data_left = *(new_C + i * ldc); - element_t data_right = 
*(new_scratch + i * sg_store_ld); - *(new_C + i * ldc) = data_right + beta_ * data_left; - } - } - } - } else if (sg_mc < tile_type::joint_matrix_M) { - if (sg_item_id < sg_mc) { - for (index_t i = 0; i < loop_limit; i++) { - if constexpr (is_beta_zero) { - element_t data_right = *(new_scratch + i * sg_store_ld); - *(new_C + i * ldc) = data_right; - } else { - element_t data_left = *(new_C + i * ldc); - element_t data_right = *(new_scratch + i * sg_store_ld); - *(new_C + i * ldc) = data_right + beta_ * data_left; - } - } - } - } else { - if (sg_item_id < tile_type::joint_matrix_M) { - for (index_t i = 0; i < loop_limit; i++) { - if constexpr (is_beta_zero) { - element_t data_right = *(new_scratch + i * sg_store_ld); - *(new_C + i * ldc) = data_right; - } else { - element_t data_left = *(new_C + i * ldc); - element_t data_right = *(new_scratch + i * sg_store_ld); - *(new_C + i * ldc) = data_right + beta_ * data_left; - } - } - } + for (int i = 0; i < loop_limit; i++, new_C += output_global_inner_offset, + new_scratch += output_local_inner_offset) { + *new_C = *new_scratch; } } } From 65d720b890b332b8605a3b86776db09ab7315b71 Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Tue, 26 Dec 2023 15:44:52 +0000 Subject: [PATCH 03/21] Added support for corner cases * Need to fix transpose case --- .../blas3/gemm_local_joint_matrix.hpp | 107 +++++++++++++++--- 1 file changed, 94 insertions(+), 13 deletions(-) diff --git a/src/operations/blas3/gemm_local_joint_matrix.hpp b/src/operations/blas3/gemm_local_joint_matrix.hpp index c0de6fcf0..3170420fd 100644 --- a/src/operations/blas3/gemm_local_joint_matrix.hpp +++ b/src/operations/blas3/gemm_local_joint_matrix.hpp @@ -320,8 +320,6 @@ class Gemm= block_rows && nc >= tile_type::joint_matrix_N) { + const index_t loop_limit = + (frags_per_sg == 1 ? 
block_cols : tile_type::joint_matrix_N) / + rows_per_iter; + for (int i = 0; i < loop_limit; i++, + new_C += output_global_inner_offset, + new_scratch += output_local_inner_offset) { + if constexpr (is_beta_zero) + *new_C = *new_scratch; + else { + auto val = *new_C; + *new_C = *new_scratch * beta_ * val; + } + } + continue; + } + if (mc < block_rows && nc < tile_type::joint_matrix_N) { + if (item_id < mc) { + const index_t loop_limit = nc; + for (int i = 0; i < loop_limit; i++, + new_C += output_global_inner_offset, + new_scratch += output_local_inner_offset) { + if constexpr (is_beta_zero) + *new_C = *new_scratch; + else { + auto val = *new_C; + *new_C = *new_scratch + beta_ * val; + } + } + } + continue; + } + if (mc < block_rows) { + if (it_mod_brows < mc) { + const index_t loop_limit = + (frags_per_sg == 1 ? block_cols : tile_type::joint_matrix_N) / + rows_per_iter; + for (int i = 0; i < loop_limit; i++, + new_C += output_global_inner_offset, + new_scratch += output_local_inner_offset) { + if constexpr (is_beta_zero) + *new_C = *new_scratch; + else { + auto val = *new_C; + *new_C = *new_scratch + beta_ * val; + } + } + } + continue; + } + if (nc < tile_type::joint_matrix_N) { + if (item_id < block_rows) { + const index_t loop_limit = nc; + for (int i = 0; i < loop_limit; i++, + new_C += output_global_inner_offset, + new_scratch += output_local_inner_offset) { + if constexpr (is_beta_zero) + *new_C = *new_scratch; + else { + auto val = *new_C; + *new_C = *new_scratch + beta_ * val; + } + } + } + continue; + } + } else { + const index_t loop_limit = + (frags_per_sg == 1 ? 
block_cols : tile_type::joint_matrix_N) / + rows_per_iter; + for (int i = 0; i < loop_limit; i++, + new_C += output_global_inner_offset, + new_scratch += output_local_inner_offset) { + if constexpr (is_beta_zero) + *new_C = *new_scratch; + else { + auto val = *new_C; + *new_C = *new_scratch + beta_ * val; + } + } } } } From 1f1b41cc0c7a352ac7a165f34352a62d51d0a0eb Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Wed, 27 Dec 2023 14:53:03 +0000 Subject: [PATCH 04/21] Fixed the transpose implementation --- .../blas3/gemm_local_joint_matrix.hpp | 44 +++++++++++-------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/src/operations/blas3/gemm_local_joint_matrix.hpp b/src/operations/blas3/gemm_local_joint_matrix.hpp index 3170420fd..8e40dd88d 100644 --- a/src/operations/blas3/gemm_local_joint_matrix.hpp +++ b/src/operations/blas3/gemm_local_joint_matrix.hpp @@ -158,10 +158,12 @@ class Gemm cols ? block_rows == 128 ? 2 : 4 : loop_iterations / 2; #pragma unroll for (index_t i = 0; i < loop_iterations; ++i) { if (!do_check<((bs % (wg_size * multiplier)) != 0)>( item_id + i * (wg_size * multiplier) < bs)) continue; - const index_t local_row_ofs = - (i % divisor) * ((wg_size * multiplier) / cols) + - i / divisor * lds * cols; const index_t row_ofs = i * ((wg_size * multiplier) / cols); const bool in_range = do_check(in_row( (item_id * multiplier) / cols, row_ofs)) && @@ -760,7 +761,7 @@ class Gemm( - in_range, ptr + row_ofs * ld, scratch + local_row_ofs, + in_range, ptr + row_ofs * ld, scratch + row_ofs * lds, [&](const index_t &ofs) PORTBLAS_ALWAYS_INLINE { return in_col((item_id * multiplier) % cols, ofs) && in_row((item_id * multiplier) / cols, row_ofs); @@ -784,14 +785,18 @@ class Gemm &id, InputPointerType s2, InputPointerType s4, CType (®_res)[frags_per_sg]) noexcept { using namespace cl::sycl::ext::oneapi::experimental::matrix; + constexpr layout pattern_a = + trans_a ? 
layout::row_major : layout::col_major; + constexpr layout pattern_b = + trans_b ? layout::row_major : layout::col_major; using AType = joint_matrix; + pattern_a>; using BType = joint_matrix; + pattern_b>; AType inA; BType inB; @@ -803,7 +808,8 @@ class Gemm Date: Thu, 28 Dec 2023 11:32:24 +0000 Subject: [PATCH 05/21] Fixed corner case output --- src/operations/blas3/gemm_local_joint_matrix.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operations/blas3/gemm_local_joint_matrix.hpp b/src/operations/blas3/gemm_local_joint_matrix.hpp index 8e40dd88d..74074073f 100644 --- a/src/operations/blas3/gemm_local_joint_matrix.hpp +++ b/src/operations/blas3/gemm_local_joint_matrix.hpp @@ -576,7 +576,7 @@ class Gemm Date: Mon, 1 Jan 2024 13:32:56 +0000 Subject: [PATCH 06/21] Code working for gemm and gemm_batched --- .../blas3/gemm_local_joint_matrix.hpp | 70 ++++++++----------- 1 file changed, 28 insertions(+), 42 deletions(-) diff --git a/src/operations/blas3/gemm_local_joint_matrix.hpp b/src/operations/blas3/gemm_local_joint_matrix.hpp index 74074073f..0b2569790 100644 --- a/src/operations/blas3/gemm_local_joint_matrix.hpp +++ b/src/operations/blas3/gemm_local_joint_matrix.hpp @@ -171,7 +171,8 @@ class Gemm( - a_.get_pointer()) + - (wg_batch_id * stridea_); - auto ptr_B = cl::sycl::multi_ptr( - b_.get_pointer()) + - (wg_batch_id * strideb_); - auto ptr_C = cl::sycl::multi_ptr( - c_.get_pointer()) + - (wg_batch_id * stridec_); + auto ptr_A = a_.get_pointer() + wg_batch_id * stridea_; + auto ptr_B = b_.get_pointer() + wg_batch_id * strideb_; + auto ptr_C = c_.get_pointer() + wg_batch_id * stridec_; auto sg = id.get_sub_group(); const index_t sg_id = sg.get_group_linear_id(); @@ -366,16 +356,17 @@ class Gemm( - id, item_id, m, n, k, mc, nc, a_size, b_size, c_size, ptr_A, lda, - ptr_B, ldb, ptr_C, ldc, scratch, s1, s2, s3, s4, out_of_range, - batch_stride, wg_batch_id, batch_size_); + id, item_id, m, n, k, mc, nc, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, + 
scratch, s1, s2, s3, s4, out_of_range, batch_stride, wg_batch_id, + batch_size_); } else { compute_panel_gemm( - id, item_id, m, n, k, mc, nc, a_size, b_size, c_size, ptr_A, lda, - ptr_B, ldb, ptr_C, ldc, scratch, s1, s2, s3, s4, out_of_range, - batch_stride, wg_batch_id, batch_size_); + id, item_id, m, n, k, mc, nc, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, + scratch, s1, s2, s3, s4, out_of_range, batch_stride, wg_batch_id, + batch_size_); } } else { + using address_t = cl::sycl::access::address_space; auto input_scratch = *reinterpret_cast *>(&scratch); @@ -385,14 +376,14 @@ class Gemm( - id, item_id, m, n, k, mc, nc, a_size, b_size, c_size, ptr_A, lda, - ptr_B, ldb, ptr_C, ldc, scratch, s1, s2, s3, s4, out_of_range, - batch_stride, wg_batch_id, batch_size_); + id, item_id, m, n, k, mc, nc, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, + scratch, s1, s2, s3, s4, out_of_range, batch_stride, wg_batch_id, + batch_size_); } else { compute_panel_gemm( - id, item_id, m, n, k, mc, nc, a_size, b_size, c_size, ptr_A, lda, - ptr_B, ldb, ptr_C, ldc, scratch, s1, s2, s3, s4, out_of_range, - batch_stride, wg_batch_id, batch_size_); + id, item_id, m, n, k, mc, nc, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, + scratch, s1, s2, s3, s4, out_of_range, batch_stride, wg_batch_id, + batch_size_); } } } @@ -431,8 +422,7 @@ class Gemm &id, const index_t &item_id, const index_t &m, const index_t &n, const index_t &orig_k, const index_t &mc, - const index_t &nc, const index_t &a_size, const index_t &b_size, - const index_t &c_size, InputPointerType orig_A, const index_t &lda, + const index_t &nc, InputPointerType orig_A, const index_t &lda, InputPointerType orig_B, const index_t &ldb, OutputPointerType orig_C, const index_t &ldc, OutputScratchPointerType s0, InputScratchPointerType s1, InputScratchPointerType s2, @@ -542,12 +532,14 @@ class Gemm 1 ? 
tile_type::joint_matrix_N : block_cols; for (index_t frag = 0; frag < frags_per_sg; frag++, C += output_global_outer_offset, nc -= tile_type::joint_matrix_N) { const index_t rows_per_iter = - nc < tile_type::joint_matrix_N ? 1 : local_range / block_rows; + nc < nc_conditional ? 1 : local_range / block_rows; const index_t output_global_inner_offset = ldc * rows_per_iter; const index_t output_local_inner_offset = ldsc * rows_per_iter; @@ -565,10 +557,8 @@ class Gemm= block_rows && nc >= tile_type::joint_matrix_N) { - const index_t loop_limit = - (frags_per_sg == 1 ? block_cols : tile_type::joint_matrix_N) / - rows_per_iter; + if (mc >= block_rows && nc >= nc_conditional) { + const index_t loop_limit = nc_conditional / rows_per_iter; for (int i = 0; i < loop_limit; i++, new_C += output_global_inner_offset, new_scratch += output_local_inner_offset) { @@ -581,7 +571,7 @@ class Gemm Date: Fri, 5 Jan 2024 00:12:27 +0000 Subject: [PATCH 07/21] Unrolled the Global Memory write loops --- src/operations/blas3/gemm_local_joint_matrix.hpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/operations/blas3/gemm_local_joint_matrix.hpp b/src/operations/blas3/gemm_local_joint_matrix.hpp index 0b2569790..a4e7dc17d 100644 --- a/src/operations/blas3/gemm_local_joint_matrix.hpp +++ b/src/operations/blas3/gemm_local_joint_matrix.hpp @@ -535,6 +535,7 @@ class Gemm 1 ? 
tile_type::joint_matrix_N : block_cols; +#pragma unroll for (index_t frag = 0; frag < frags_per_sg; frag++, C += output_global_outer_offset, nc -= tile_type::joint_matrix_N) { @@ -559,6 +560,7 @@ class Gemm= block_rows && nc >= nc_conditional) { const index_t loop_limit = nc_conditional / rows_per_iter; +#pragma unroll for (int i = 0; i < loop_limit; i++, new_C += output_global_inner_offset, new_scratch += output_local_inner_offset) { @@ -574,6 +576,7 @@ class Gemm Date: Fri, 5 Jan 2024 15:09:30 +0000 Subject: [PATCH 08/21] Reorder computation loops --- .../blas3/gemm_local_joint_matrix.hpp | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/src/operations/blas3/gemm_local_joint_matrix.hpp b/src/operations/blas3/gemm_local_joint_matrix.hpp index a4e7dc17d..c3e18496b 100644 --- a/src/operations/blas3/gemm_local_joint_matrix.hpp +++ b/src/operations/blas3/gemm_local_joint_matrix.hpp @@ -790,31 +790,29 @@ class Gemm; - AType inA; - BType inB; - const index_t strideA = ldsa; const index_t strideB = ldsb; auto sg = id.get_sub_group(); #pragma unroll - for (index_t frag = 0; frag < frags_per_sg; frag++) { - auto new_B = s2 + frag * (trans_b ? tile_type::joint_matrix_N - : tile_type::joint_matrix_N * ldsb); - auto new_A = s4; + for (index_t i = 0; i < cl_elems / tile_type::joint_matrix_K; i++) { + auto new_B = s2; + AType inA; - for (index_t i = 0; i < cl_elems / tile_type::joint_matrix_K; i++) { - joint_matrix_load(sg, inA, new_A, strideA); // M - joint_matrix_load(sg, inB, new_B, strideB); // N + joint_matrix_load(sg, inA, s4, strideA); // M + for (index_t frag = 0; frag < frags_per_sg; frag++) { + BType inB; + joint_matrix_load(sg, inB, new_B, strideB); // N joint_matrix_mad(sg, reg_res[frag], inA, inB, reg_res[frag]); - - new_A += (trans_a ? tile_type::joint_matrix_K - : tile_type::joint_matrix_K * strideA); - new_B += (trans_b ? tile_type::joint_matrix_K * strideB - : tile_type::joint_matrix_K); + new_B += (trans_b ? 
tile_type::joint_matrix_N + : tile_type::joint_matrix_N * ldsb); } + s4 += (trans_a ? tile_type::joint_matrix_K + : tile_type::joint_matrix_K * strideA); + s2 += (trans_b ? tile_type::joint_matrix_K * strideB + : tile_type::joint_matrix_K); } } From eca463487ed2fb7daf2e8bf7f9a17568ece50e6d Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Mon, 8 Jan 2024 11:20:08 +0000 Subject: [PATCH 09/21] Updated comments and removed redundant code --- .../blas3/gemm_local_joint_matrix.hpp | 27 +++++++++---------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/src/operations/blas3/gemm_local_joint_matrix.hpp b/src/operations/blas3/gemm_local_joint_matrix.hpp index c3e18496b..358848870 100644 --- a/src/operations/blas3/gemm_local_joint_matrix.hpp +++ b/src/operations/blas3/gemm_local_joint_matrix.hpp @@ -164,11 +164,11 @@ class Gemm( item_id, m, n, k, A, lda, B, ldb, s1, s3, out_of_range); id.barrier(cl::sycl::access::fence_space::local_space); - compute_block_gemm(id, s2, s4, reg_res); + compute_block_gemm(id, s2, s4, reg_res); A += cl_elems * (trans_a ? 1 : lda); B += cl_elems * (trans_b ? 
ldb : 1); @@ -458,7 +458,7 @@ class Gemm( item_id, m, n, k, A, lda, B, ldb, s1, s3, out_of_range); id.barrier(cl::sycl::access::fence_space::local_space); - compute_block_gemm(id, s2, s4, reg_res); + compute_block_gemm(id, s2, s4, reg_res); sync_smem(id, ofs, s1, s2, s3, s4); @@ -489,8 +489,6 @@ class Gemm(id.get_local_linear_id()); - const index_t local_range = static_cast(id.get_local_range(0)); const index_t sg_id = static_cast(sg.get_group_linear_id()); const index_t output_local_store_offset = @@ -540,7 +537,7 @@ class Gemm( item_id + i * (wg_size * multiplier) < bs)) continue; @@ -765,14 +762,14 @@ class Gemm + template PORTBLAS_INLINE void compute_block_gemm( const cl::sycl::nd_item<1> &id, InputPointerType s2, InputPointerType s4, CType (®_res)[frags_per_sg]) noexcept { From 37d26e1d91ff439c02e28ce2f528a0e0b4cdee3d Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Wed, 10 Jan 2024 14:05:20 +0000 Subject: [PATCH 10/21] Fixed build with release compiler --- src/operations/blas3/gemm_local_joint_matrix.hpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/operations/blas3/gemm_local_joint_matrix.hpp b/src/operations/blas3/gemm_local_joint_matrix.hpp index 358848870..e6ec1e6ba 100644 --- a/src/operations/blas3/gemm_local_joint_matrix.hpp +++ b/src/operations/blas3/gemm_local_joint_matrix.hpp @@ -544,7 +544,17 @@ class Gemm(get_wi_data(sg, reg_res[frag])[i]); + } +#else joint_matrix_copy(sg, reg_res[frag], float_out); +#endif joint_matrix_apply(sg, float_out, [=](element_t &x) { x *= alpha_; }); id.barrier(cl::sycl::access::fence_space::local_space); @@ -802,7 +812,11 @@ class Gemm Date: Tue, 16 Jan 2024 14:11:31 +0000 Subject: [PATCH 11/21] Fix loop bounds --- src/operations/blas3/gemm_local_joint_matrix.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/operations/blas3/gemm_local_joint_matrix.hpp b/src/operations/blas3/gemm_local_joint_matrix.hpp index e6ec1e6ba..9550c7892 100644 --- 
a/src/operations/blas3/gemm_local_joint_matrix.hpp +++ b/src/operations/blas3/gemm_local_joint_matrix.hpp @@ -545,10 +545,9 @@ class Gemm(get_wi_data(sg, reg_res[frag])[i]); } @@ -801,9 +800,10 @@ class Gemm Date: Thu, 18 Jan 2024 17:28:25 +0000 Subject: [PATCH 12/21] Add functions for checking lower precision --- test/blas_test.hpp | 24 +++++++++++++++++++++++ test/unittest/blas3/blas3_gemm_common.hpp | 22 +++++++++++++++++++++ test/unittest/blas3/blas3_trsm_test.cpp | 8 ++++++++ 3 files changed, 54 insertions(+) diff --git a/test/blas_test.hpp b/test/blas_test.hpp index d159109db..fef0d60cb 100644 --- a/test/blas_test.hpp +++ b/test/blas_test.hpp @@ -223,6 +223,30 @@ static inline void fill_trsm_matrix(std::vector &A, size_t k, } } +/** + * @brief Set to zero the last n bits of a float. + * + * @param val input/output float value. + * @param nbits number of last bit set to zero. It is set by default to 13 since + * this is the difference of the number of bits of the mantissa between floats + * (23) and FP16 / NVIDIA TF32 (10). + */ +static inline void set_to_zero_last_nbits(float &val, int32_t nbits = 13) { + int32_t *int_pntr = reinterpret_cast(&val); + *int_pntr = (*int_pntr >> nbits) << nbits; +} + +/** + * @brief Set to zero the last n bits of floats contained in a vector. + * + * @param val input/output float vector. + * @param nbits number of last bit set to zero. + */ +static inline void set_to_zero_last_nbits(std::vector &vec, + int32_t nbits = 13) { + for (float &val : vec) set_to_zero_last_nbits(val, nbits); +} + /** * @brief Helper class for dumping arguments to a stream, in a format compatible * with google test test names. 
diff --git a/test/unittest/blas3/blas3_gemm_common.hpp b/test/unittest/blas3/blas3_gemm_common.hpp index 2cd832a99..ae305cf3d 100644 --- a/test/unittest/blas3/blas3_gemm_common.hpp +++ b/test/unittest/blas3/blas3_gemm_common.hpp @@ -130,6 +130,17 @@ inline void verify_gemm(const gemm_arguments_t arguments) { fill_random(a_m); fill_random(b_m); fill_random(c_m_gpu); + + const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX"); + if (en_joint_matrix != NULL && std::is_same::value && + *en_joint_matrix == '1') { + set_to_zero_last_nbits(a_m); + set_to_zero_last_nbits(b_m); + set_to_zero_last_nbits(c_m_gpu); + set_to_zero_last_nbits(alpha); + set_to_zero_last_nbits(beta); + } + std::vector c_m_cpu = c_m_gpu; // Use system blas to create a reference output @@ -291,6 +302,17 @@ inline void verify_gemm( fill_random(a_m); fill_random(b_m); fill_random(c_m_gpu); + + const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX"); + if (en_joint_matrix != NULL && std::is_same::value && + *en_joint_matrix == '1') { + set_to_zero_last_nbits(a_m); + set_to_zero_last_nbits(b_m); + set_to_zero_last_nbits(c_m_gpu); + set_to_zero_last_nbits(alpha); + set_to_zero_last_nbits(beta); + } + std::vector c_m_cpu = c_m_gpu; // Use system blas to create a reference output diff --git a/test/unittest/blas3/blas3_trsm_test.cpp b/test/unittest/blas3/blas3_trsm_test.cpp index 793bd7ee5..76c87e2db 100644 --- a/test/unittest/blas3/blas3_trsm_test.cpp +++ b/test/unittest/blas3/blas3_trsm_test.cpp @@ -62,6 +62,14 @@ void run_test(const combination_t combi) { static_cast(unusedValue)); fill_random(B); + const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX"); + if (en_joint_matrix != NULL && std::is_same::value && + *en_joint_matrix == '1') { + set_to_zero_last_nbits(A); + set_to_zero_last_nbits(B); + set_to_zero_last_nbits(alpha); + } + // Create a copy of B to calculate the reference outputs cpu_B = B; reference_blas::trsm(&side, &uplo, &trans, &diag, m, n, From 
60fa20a48d0935f020b73dc95a288c9a5b51ea2b Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Thu, 25 Jan 2024 16:07:12 +0000 Subject: [PATCH 13/21] Fixed synchronization for double buffering --- src/operations/blas3/gemm_local_joint_matrix.hpp | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/operations/blas3/gemm_local_joint_matrix.hpp b/src/operations/blas3/gemm_local_joint_matrix.hpp index 9550c7892..63810ea59 100644 --- a/src/operations/blas3/gemm_local_joint_matrix.hpp +++ b/src/operations/blas3/gemm_local_joint_matrix.hpp @@ -449,8 +449,11 @@ class Gemm(id, ofs, s1, s2, s3, s4); + sync_smem(id, ofs, s1, s2, s3, + s4); k -= cl_elems; } @@ -460,8 +463,11 @@ class Gemm(id, ofs, s1, s2, s3, s4); + sync_smem(id, ofs, s1, s2, s3, + s4); } // store the output From ba4cd8743346755febb009830169f9da4014ee05 Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Wed, 31 Jan 2024 12:38:23 +0000 Subject: [PATCH 14/21] Restriced VectorSize to 1 for joint_matrix * Updated the load/store file with proper vectorized load/store for future use. --- .../blas3/gemm_load_store_joint_matrix.hpp | 27 ++++++++++-- .../blas3/gemm_local_joint_matrix.hpp | 41 ++++++++----------- 2 files changed, 42 insertions(+), 26 deletions(-) diff --git a/src/operations/blas3/gemm_load_store_joint_matrix.hpp b/src/operations/blas3/gemm_load_store_joint_matrix.hpp index 81eb0625e..35e43d1f8 100644 --- a/src/operations/blas3/gemm_load_store_joint_matrix.hpp +++ b/src/operations/blas3/gemm_load_store_joint_matrix.hpp @@ -86,6 +86,7 @@ struct PacketizeJointMatrix { *dest = round_to_tf32(val); } } + /*! @brief Performs a vectorised load using sycl::vec::load when the current * block is internal. In the case where k < the * number of elements being loaded then edge loads will be element wise with @@ -114,6 +115,7 @@ struct PacketizeJointMatrix { } store(packet, dest); } + /*! @brief Store a vector packet into local memory when the source is * transposed. 
This will untranspose the elements individually when storing so * the data in local memory is always consistent. @@ -156,16 +158,35 @@ struct PacketizeJointMatrix { address_t::local_space>, DestPointerType>::value) { using dtype = cl::sycl::half; - *dest = static_cast(packet[0]); + cl::sycl::vec new_vec{}; + for (index_t i = 0; i < packet_size; i++) { + reinterpret_cast(&new_vec)[i] = + static_cast(reinterpret_cast(&packet)[i]); + } + new_vec.template store( + 0, cl::sycl::multi_ptr(dest)); } else if constexpr (std::is_same, DestPointerType>::value) { using dtype = cl::sycl::ext::oneapi::bfloat16; - *dest = static_cast(packet[0]); + cl::sycl::vec new_vec; + for (index_t i = 0; i < packet_size; i++) { + reinterpret_cast(&new_vec)[i] = + static_cast(reinterpret_cast(&packet)[i]); + } + new_vec.template store( + 0, cl::sycl::multi_ptr(dest)); } else { using namespace cl::sycl::ext::oneapi::experimental::matrix; - *dest = round_to_tf32(packet[0]); + using dtype = float; + cl::sycl::vec new_vec; + for (index_t i = 0; i < packet_size; i++) { + reinterpret_cast(&new_vec)[i] = + round_to_tf32(reinterpret_cast(&packet)[i]); + } + new_vec.template store( + 0, cl::sycl::multi_ptr(dest)); } } }; diff --git a/src/operations/blas3/gemm_local_joint_matrix.hpp b/src/operations/blas3/gemm_local_joint_matrix.hpp index 63810ea59..1a0b7524a 100644 --- a/src/operations/blas3/gemm_local_joint_matrix.hpp +++ b/src/operations/blas3/gemm_local_joint_matrix.hpp @@ -83,7 +83,6 @@ class Gemm::type; using packetize_t = PacketizeJointMatrix; - using vector_t = typename packetize_t::PacketType; using address_t = cl::sycl::access::address_space; // enable easier access to tile dimensions @@ -156,6 +155,9 @@ class Gemm::value, "This code is only supported for float data type."); + static_assert(VectorSize == 1, + "Vectorization not supported for joint_matrix."); + //! @brief leading dimension of block of A in local static constexpr index_t ldsa = (trans_a ? 
cl_elems : block_rows) + @@ -366,7 +368,6 @@ class Gemm *>(&scratch); @@ -721,25 +722,22 @@ class Gemm( - item_id + i * (wg_size * multiplier) < bs)) + if (!do_check<((bs % wg_size) != 0)>(item_id + i * wg_size < bs)) continue; - const index_t col_ofs = i * ((wg_size * multiplier) / rows); + const index_t col_ofs = i * (wg_size / rows); const bool in_range = - do_check( - in_row(((item_id * multiplier) % rows), multiplier - 1)) && + do_check(in_row((item_id % rows), 0)) && do_check( - in_col((item_id * multiplier / rows), col_ofs)); + in_col((item_id / rows), col_ofs)); packetize_t::template load( in_range, ptr + col_ofs * ld, scratch + col_ofs * lds, [&](const index_t &ofs) { - return in_row((item_id * multiplier) % rows, ofs) && - in_col((item_id * multiplier) / rows, col_ofs); + return in_row(item_id % rows, ofs) && + in_col(item_id / rows, col_ofs); }); } } @@ -751,24 +749,21 @@ class Gemm( - item_id + i * (wg_size * multiplier) < bs)) + if (!do_check<((bs % wg_size) != 0)>(item_id + i * wg_size < bs)) continue; - const index_t row_ofs = i * ((wg_size * multiplier) / cols); - const bool in_range = do_check(in_row( - (item_id * multiplier) / cols, row_ofs)) && - do_check(in_col( - (item_id * multiplier) % cols, multiplier - 1)); + const index_t row_ofs = i * (wg_size / cols); + const bool in_range = + do_check(in_row(item_id / cols, row_ofs)) && + do_check(in_col(item_id % cols, 0)); packetize_t::template load( in_range, ptr + row_ofs * ld, scratch + row_ofs * lds, [&](const index_t &ofs) PORTBLAS_ALWAYS_INLINE { - return in_col((item_id * multiplier) % cols, ofs) && - in_row((item_id * multiplier) / cols, row_ofs); + return in_col(item_id % cols, ofs) && + in_row(item_id / cols, row_ofs); }); } } From 78af1dabeea09a6f56643a06fa1b9289b09e1d81 Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Wed, 31 Jan 2024 13:23:29 +0000 Subject: [PATCH 15/21] Fixed race condition * Removed extra store function --- .../blas3/gemm_load_store_joint_matrix.hpp | 83 
+++++++------------ .../blas3/gemm_local_joint_matrix.hpp | 7 +- 2 files changed, 35 insertions(+), 55 deletions(-) diff --git a/src/operations/blas3/gemm_load_store_joint_matrix.hpp b/src/operations/blas3/gemm_load_store_joint_matrix.hpp index 35e43d1f8..c8e28f864 100644 --- a/src/operations/blas3/gemm_load_store_joint_matrix.hpp +++ b/src/operations/blas3/gemm_load_store_joint_matrix.hpp @@ -57,18 +57,16 @@ struct PacketizeJointMatrix { /*! @brief Performs a coalesced non-vectorized load when the current block is * not internal. - * @tparam trans Whether the source matrix is transposed or not. * @tparam internal True if the current block is internal and no bounds * checking is required. - * @tparam ld The leading dimension of the destination memory. */ - template + template static PORTBLAS_INLINE typename std::enable_if::type load( const bool in_range, SrcPointerType src, DestPointerType dest, EdgePredicate) { - value_t val = in_range ? *(src) : value_t{0}; + value_t val = in_range ? *src : value_t{0}; using address_t = cl::sycl::access::address_space; if constexpr (std::is_same, @@ -91,68 +89,51 @@ struct PacketizeJointMatrix { * block is internal. In the case where k < the * number of elements being loaded then edge loads will be element wise with * additional bounds checking. - * @tparam trans Whether the source matrix is transposed or not. * @tparam internal True if the current block is internal and no bounds * checking is required. - * @tparam ld The leading dimension of the destination memory. 
*/ - template + */ + template static PORTBLAS_INLINE typename std::enable_if::type load( const bool in_range, SrcPointerType src, DestPointerType dest, EdgePredicate edge_in_range) { PacketType packet{}; + using address_t = cl::sycl::access::address_space; if (in_range) { - using address_t = cl::sycl::access::address_space; packet.template load( 0, cl::sycl::multi_ptr(src)); + store(packet, dest); } else { + // avoid writing to variable, instead directly write to + // shared local memory to avoid race condition experienced + // with release compiler. #pragma unroll - for (index_t i = 0; i < packet_size; i++) { - reinterpret_cast(&packet)[i] = - edge_in_range(i) ? *(src + i) : value_t{0}; - } - } - store(packet, dest); - } - - /*! @brief Store a vector packet into local memory when the source is - * transposed. This will untranspose the elements individually when storing so - * the data in local memory is always consistent. - * @tparam trans Whether the source matrix is transposed or not. - * @tparam ld The leading dimension of the destination memory.*/ - template - static PORTBLAS_INLINE typename std::enable_if::type store( - PacketType &packet, DestPointerType dest) { - using address_t = cl::sycl::access::address_space; -#pragma unroll - for (index_t i = 0; i < packet_size; i++) { - value_t val = reinterpret_cast(&packet)[i]; - if constexpr (std::is_same, - DestPointerType>::value) { - using dtype = cl::sycl::half; - *(dest + ld * i) = static_cast(val); - } else if constexpr (std::is_same, - DestPointerType>::value) { - using dtype = cl::sycl::ext::oneapi::bfloat16; - *(dest + ld * i) = static_cast(val); - } else { - using namespace cl::sycl::ext::oneapi::experimental::matrix; - *(dest + ld * i) = round_to_tf32(val); + for (index_t i = 0; i < packet_size; i++, dest++, src++) { + if constexpr (std::is_same, + DestPointerType>::value) { + using dtype = cl::sycl::half; + *dest = static_cast(edge_in_range(i) ? 
*src : 0); + } else if constexpr (std::is_same, + DestPointerType>::value) { + using dtype = cl::sycl::ext::oneapi::bfloat16; + *dest = static_cast(edge_in_range(i) ? *src : 0); + } else { + using namespace cl::sycl::ext::oneapi::experimental::matrix; + *dest = edge_in_range(i) ? round_to_tf32(*src) : 0.f; + } } } } - /*! @brief Store a vector packet into local memory when the source is not - * transposed. This will use sycl::vec::store function. - * @tparam trans Whether the source matrix is transposed or not. - * @tparam ld The leading dimension of the destination memory.*/ - template - static PORTBLAS_INLINE typename std::enable_if::type store( - PacketType &packet, DestPointerType dest) { + /*! @brief Store a vector packet into local memory. This will use + * sycl::vec::store function. + */ + template + static PORTBLAS_INLINE void store(PacketType &packet, DestPointerType dest) { using address_t = cl::sycl::access::address_space; if constexpr (std::is_same, diff --git a/src/operations/blas3/gemm_local_joint_matrix.hpp b/src/operations/blas3/gemm_local_joint_matrix.hpp index 1a0b7524a..3298203f2 100644 --- a/src/operations/blas3/gemm_local_joint_matrix.hpp +++ b/src/operations/blas3/gemm_local_joint_matrix.hpp @@ -730,10 +730,9 @@ class Gemm(in_row((item_id % rows), 0)) && - do_check( - in_col((item_id / rows), col_ofs)); + do_check(in_col((item_id / rows), col_ofs)); - packetize_t::template load( + packetize_t::template load( in_range, ptr + col_ofs * ld, scratch + col_ofs * lds, [&](const index_t &ofs) { return in_row(item_id % rows, ofs) && @@ -759,7 +758,7 @@ class Gemm(in_row(item_id / cols, row_ofs)) && do_check(in_col(item_id % cols, 0)); - packetize_t::template load( + packetize_t::template load( in_range, ptr + row_ofs * ld, scratch + row_ofs * lds, [&](const index_t &ofs) PORTBLAS_ALWAYS_INLINE { return in_col(item_id % cols, ofs) && From 9767329c62e843c22408b826a04cd97a1387296c Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Wed, 31 Jan 2024 
14:17:09 +0000 Subject: [PATCH 16/21] Fixed compilation error with bfloat16 type --- .../blas3/gemm_load_store_joint_matrix.hpp | 21 +++++++++---------- test/blas_test.hpp | 3 ++- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/operations/blas3/gemm_load_store_joint_matrix.hpp b/src/operations/blas3/gemm_load_store_joint_matrix.hpp index c8e28f864..876817158 100644 --- a/src/operations/blas3/gemm_load_store_joint_matrix.hpp +++ b/src/operations/blas3/gemm_load_store_joint_matrix.hpp @@ -77,8 +77,8 @@ struct PacketizeJointMatrix { cl::sycl::ext::oneapi::bfloat16, address_t::local_space>, DestPointerType>::value) { - using dtype = cl::sycl::ext::oneapi::bfloat16; - *dest = static_cast(val); + using namespace cl::sycl::ext::oneapi; + *dest = bfloat16(val); } else { using namespace cl::sycl::ext::oneapi::experimental::matrix; *dest = round_to_tf32(val); @@ -119,8 +119,8 @@ struct PacketizeJointMatrix { cl::sycl::ext::oneapi::bfloat16, address_t::local_space>, DestPointerType>::value) { - using dtype = cl::sycl::ext::oneapi::bfloat16; - *dest = static_cast(edge_in_range(i) ? *src : 0); + using namespace cl::sycl::ext::oneapi; + *dest = bfloat16(edge_in_range(i) ? *src : 0.f); } else { using namespace cl::sycl::ext::oneapi::experimental::matrix; *dest = edge_in_range(i) ? round_to_tf32(*src) : 0.f; @@ -150,14 +150,13 @@ struct PacketizeJointMatrix { cl::sycl::ext::oneapi::bfloat16, address_t::local_space>, DestPointerType>::value) { - using dtype = cl::sycl::ext::oneapi::bfloat16; - cl::sycl::vec new_vec; - for (index_t i = 0; i < packet_size; i++) { - reinterpret_cast(&new_vec)[i] = - static_cast(reinterpret_cast(&packet)[i]); + // sycl::vec doesn't accept bfloat16 as a valid input type + // so we need to write the packet elements individually to + // the shared memory. 
+ using namespace cl::sycl::ext::oneapi; + for (index_t i = 0; i < packet_size; i++, dest++) { + *dest = bfloat16(reinterpret_cast(&packet)[i]); } - new_vec.template store( - 0, cl::sycl::multi_ptr(dest)); } else { using namespace cl::sycl::ext::oneapi::experimental::matrix; using dtype = float; diff --git a/test/blas_test.hpp b/test/blas_test.hpp index fef0d60cb..94c3689bf 100644 --- a/test/blas_test.hpp +++ b/test/blas_test.hpp @@ -229,7 +229,8 @@ static inline void fill_trsm_matrix(std::vector &A, size_t k, * @param val input/output float value. * @param nbits number of last bit set to zero. It is set by default to 13 since * this is the difference of the number of bits of the mantissa between floats - * (23) and FP16 / NVIDIA TF32 (10). + * (23) and FP16 / NVIDIA TF32 (10). For bfloat16, this value needs to be set to + * 16 to get correct result. */ static inline void set_to_zero_last_nbits(float &val, int32_t nbits = 13) { int32_t *int_pntr = reinterpret_cast(&val); From d1f348b010526d4933379df92d6f637681ea3fad Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Thu, 1 Feb 2024 11:47:39 +0000 Subject: [PATCH 17/21] Increase error margins for trsm tests when joint_matrix is used --- benchmark/portblas/blas3/trsm.cpp | 13 +++--- common/include/common/float_comparison.hpp | 54 ++++++++++++++++------ test/unittest/blas3/blas3_trsm_test.cpp | 2 +- 3 files changed, 47 insertions(+), 22 deletions(-) diff --git a/benchmark/portblas/blas3/trsm.cpp b/benchmark/portblas/blas3/trsm.cpp index 0afd170ec..e059e3127 100644 --- a/benchmark/portblas/blas3/trsm.cpp +++ b/benchmark/portblas/blas3/trsm.cpp @@ -97,7 +97,7 @@ void run(benchmark::State& state, blas::SB_Handle* sb_handle_ptr, char side, } std::ostringstream err_stream; - if (!utils::compare_vectors(b_temp, x_ref, err_stream, "")) { + if (!utils::compare_vectors(b_temp, x_ref, err_stream, "", true)) { const std::string& err_str = err_stream.str(); state.SkipWithError(err_str.c_str()); *success = false; @@ -181,8 
+181,8 @@ void register_benchmark(blas::SB_Handle* sb_handle_ptr, bool* success, }; benchmark::RegisterBenchmark( blas_benchmark::utils::get_name( - side, uplo, trans, diag, m, n, - mem_type).c_str(), + side, uplo, trans, diag, m, n, mem_type) + .c_str(), BM_lambda, sb_handle_ptr, side, uplo, trans, diag, m, n, alpha, success) ->UseRealTime(); } @@ -193,7 +193,8 @@ void register_benchmark(blas_benchmark::Args& args, blas::SB_Handle* sb_handle_ptr, bool* success) { auto trsm_params = blas_benchmark::utils::get_trsm_params(args); register_benchmark( - sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, trsm_params); + sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, + trsm_params); #ifdef SB_ENABLE_USM register_benchmark( sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM, trsm_params); @@ -201,8 +202,8 @@ void register_benchmark(blas_benchmark::Args& args, } namespace blas_benchmark { -void create_benchmark(blas_benchmark::Args& args, blas::SB_Handle* sb_handle_ptr, - bool* success) { +void create_benchmark(blas_benchmark::Args& args, + blas::SB_Handle* sb_handle_ptr, bool* success) { BLAS_REGISTER_BENCHMARK(args, sb_handle_ptr, success); } } // namespace blas_benchmark diff --git a/common/include/common/float_comparison.hpp b/common/include/common/float_comparison.hpp index 1222ccc41..4d3e49bba 100644 --- a/common/include/common/float_comparison.hpp +++ b/common/include/common/float_comparison.hpp @@ -120,17 +120,27 @@ scalar_t clamp_to_limits(scalar_t v) { * Indicates the tolerated margin for relative differences */ template -inline scalar_t getRelativeErrorMargin() { +inline scalar_t getRelativeErrorMargin(const bool is_trsm) { /* Measured empirically with gemm. 
The dimensions of the matrices (even k) * don't seem to have an impact on the observed relative differences * In the cases where the relative error is relevant (non close to zero), * relative differences of up to 0.002 were observed for float */ - return static_cast(0.005); + scalar_t margin = 0.005; + if (is_trsm) { + const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX"); + if (en_joint_matrix != NULL && std::is_same::value && + *en_joint_matrix == '1') { + // increase error margin for mixed precision calculation + // for trsm operator. + margin = 0.009f; + } + } + return margin; } template <> -inline double getRelativeErrorMargin() { +inline double getRelativeErrorMargin(const bool) { /* Measured empirically with gemm. The dimensions of the matrices (even k) * don't seem to have an impact on the observed relative differences * In the cases where the relative error is relevant (non close to zero), @@ -142,7 +152,7 @@ inline double getRelativeErrorMargin() { #ifdef BLAS_DATA_TYPE_HALF template <> -inline cl::sycl::half getRelativeErrorMargin() { +inline cl::sycl::half getRelativeErrorMargin(const bool) { // Measured empirically with gemm return 0.05f; } @@ -152,16 +162,27 @@ inline cl::sycl::half getRelativeErrorMargin() { * scalars are close to 0) */ template -inline scalar_t getAbsoluteErrorMargin() { +inline scalar_t getAbsoluteErrorMargin(const bool is_trsm) { /* Measured empirically with gemm. * In the cases where the relative error is irrelevant (close to zero), * absolute differences of up to 0.0006 were observed for float */ - return 0.001f; + scalar_t margin = 0.001f; + if (is_trsm) { + const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX"); + if (en_joint_matrix != NULL && std::is_same::value && + *en_joint_matrix == '1') { + // increase error margin for mixed precision calculation + // for trsm operator. 
+ margin = 0.009f; + } + } + + return margin; } template <> -inline double getAbsoluteErrorMargin() { +inline double getAbsoluteErrorMargin(const bool) { /* Measured empirically with gemm. * In the cases where the relative error is irrelevant (close to zero), * absolute differences of up to 10^-12 were observed for double @@ -171,7 +192,7 @@ inline double getAbsoluteErrorMargin() { #ifdef BLAS_DATA_TYPE_HALF template <> -inline cl::sycl::half getAbsoluteErrorMargin() { +inline cl::sycl::half getAbsoluteErrorMargin(const bool) { // Measured empirically with gemm. return 1.0f; } @@ -181,7 +202,8 @@ inline cl::sycl::half getAbsoluteErrorMargin() { * Compare two scalars and returns false if the difference is not acceptable. */ template -inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2) { +inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2, + const bool is_trsm = false) { // Shortcut, also handles case where both are zero if (scalar1 == scalar2) { return true; @@ -196,12 +218,13 @@ inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2) { // Close to zero, the relative error doesn't work, use absolute error if (scalar1 == scalar_t{0} || scalar2 == scalar_t{0} || - absolute_diff < getAbsoluteErrorMargin()) { - return (absolute_diff < getAbsoluteErrorMargin()); + absolute_diff < getAbsoluteErrorMargin(is_trsm)) { + return (absolute_diff < getAbsoluteErrorMargin(is_trsm)); } // Use relative error const auto absolute_sum = utils::abs(scalar1) + utils::abs(scalar2); - return (absolute_diff / absolute_sum) < getRelativeErrorMargin(); + return (absolute_diff / absolute_sum) < + getRelativeErrorMargin(is_trsm); } /** @@ -215,7 +238,8 @@ template inline bool compare_vectors(std::vector const& vec, std::vector const& ref, std::ostream& err_stream = std::cerr, - std::string end_line = "\n") { + std::string end_line = "\n", + const bool is_trsm = false) { if (vec.size() != ref.size()) { err_stream << "Error: tried 
to compare vectors of different sizes" << std::endl; @@ -223,7 +247,7 @@ inline bool compare_vectors(std::vector const& vec, } for (int i = 0; i < vec.size(); ++i) { - if (!almost_equal(vec[i], ref[i])) { + if (!almost_equal(vec[i], ref[i], is_trsm)) { err_stream << "Value mismatch at index " << i << ": " << vec[i] << "; expected " << ref[i] << end_line; return false; @@ -244,7 +268,7 @@ template inline bool compare_vectors(std::vector> const& vec, std::vector> const& ref, std::ostream& err_stream = std::cerr, - std::string end_line = "\n") { + std::string end_line = "\n", bool is_trsm = false) { if (vec.size() != ref.size()) { err_stream << "Error: tried to compare vectors of different sizes" << std::endl; diff --git a/test/unittest/blas3/blas3_trsm_test.cpp b/test/unittest/blas3/blas3_trsm_test.cpp index 76c87e2db..b0890d9e1 100644 --- a/test/unittest/blas3/blas3_trsm_test.cpp +++ b/test/unittest/blas3/blas3_trsm_test.cpp @@ -92,7 +92,7 @@ void run_test(const combination_t combi) { blas::helper::copy_to_host(q, b_gpu, B.data(), B.size()); sb_handle.wait(event); - bool isAlmostEqual = utils::compare_vectors(cpu_B, B); + bool isAlmostEqual = utils::compare_vectors(cpu_B, B, std::cerr, "", true); ASSERT_TRUE(isAlmostEqual); From 0c0adaa0356410ece4da1a4386a1e81e3dce07a7 Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Thu, 1 Feb 2024 16:30:05 +0000 Subject: [PATCH 18/21] Address feedback --- benchmark/portblas/blas3/trsm.cpp | 8 +++- common/include/common/float_comparison.hpp | 54 ++++++++-------------- test/unittest/blas3/blas3_trsm_test.cpp | 7 ++- 3 files changed, 33 insertions(+), 36 deletions(-) diff --git a/benchmark/portblas/blas3/trsm.cpp b/benchmark/portblas/blas3/trsm.cpp index e059e3127..8d28cec4c 100644 --- a/benchmark/portblas/blas3/trsm.cpp +++ b/benchmark/portblas/blas3/trsm.cpp @@ -97,7 +97,13 @@ void run(benchmark::State& state, blas::SB_Handle* sb_handle_ptr, char side, } std::ostringstream err_stream; - if (!utils::compare_vectors(b_temp, 
x_ref, err_stream, "", true)) { + const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX"); + if (!utils::compare_vectors(b_temp, x_ref, err_stream, "", + (en_joint_matrix != NULL) && + (std::is_same::value) && + (*en_joint_matrix == '1') + ? 2 + : 1)) { const std::string& err_str = err_stream.str(); state.SkipWithError(err_str.c_str()); *success = false; diff --git a/common/include/common/float_comparison.hpp b/common/include/common/float_comparison.hpp index 4d3e49bba..557938b5e 100644 --- a/common/include/common/float_comparison.hpp +++ b/common/include/common/float_comparison.hpp @@ -120,27 +120,20 @@ scalar_t clamp_to_limits(scalar_t v) { * Indicates the tolerated margin for relative differences */ template -inline scalar_t getRelativeErrorMargin(const bool is_trsm) { +inline scalar_t getRelativeErrorMargin(const int32_t margin_multiplier = 1) { /* Measured empirically with gemm. The dimensions of the matrices (even k) * don't seem to have an impact on the observed relative differences * In the cases where the relative error is relevant (non close to zero), * relative differences of up to 0.002 were observed for float */ scalar_t margin = 0.005; - if (is_trsm) { - const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX"); - if (en_joint_matrix != NULL && std::is_same::value && - *en_joint_matrix == '1') { - // increase error margin for mixed precision calculation - // for trsm operator. - margin = 0.009f; - } - } - return margin; + // increase error margin for mixed precision calculation + // for trsm operator. + return margin * margin_multiplier; } template <> -inline double getRelativeErrorMargin(const bool) { +inline double getRelativeErrorMargin(const int32_t) { /* Measured empirically with gemm. 
The dimensions of the matrices (even k) * don't seem to have an impact on the observed relative differences * In the cases where the relative error is relevant (non close to zero), @@ -152,7 +145,7 @@ inline double getRelativeErrorMargin(const bool) { #ifdef BLAS_DATA_TYPE_HALF template <> -inline cl::sycl::half getRelativeErrorMargin(const bool) { +inline cl::sycl::half getRelativeErrorMargin(const int32_t) { // Measured empirically with gemm return 0.05f; } @@ -162,27 +155,19 @@ inline cl::sycl::half getRelativeErrorMargin(const bool) { * scalars are close to 0) */ template -inline scalar_t getAbsoluteErrorMargin(const bool is_trsm) { +inline scalar_t getAbsoluteErrorMargin(const int32_t margin_multiplier = 1) { /* Measured empirically with gemm. * In the cases where the relative error is irrelevant (close to zero), * absolute differences of up to 0.0006 were observed for float */ scalar_t margin = 0.001f; - if (is_trsm) { - const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX"); - if (en_joint_matrix != NULL && std::is_same::value && - *en_joint_matrix == '1') { - // increase error margin for mixed precision calculation - // for trsm operator. - margin = 0.009f; - } - } - - return margin; + // increase error margin for mixed precision calculation + // for trsm operator. + return margin * margin_multiplier; } template <> -inline double getAbsoluteErrorMargin(const bool) { +inline double getAbsoluteErrorMargin(const int32_t) { /* Measured empirically with gemm. * In the cases where the relative error is irrelevant (close to zero), * absolute differences of up to 10^-12 were observed for double @@ -192,7 +177,7 @@ inline double getAbsoluteErrorMargin(const bool) { #ifdef BLAS_DATA_TYPE_HALF template <> -inline cl::sycl::half getAbsoluteErrorMargin(const bool) { +inline cl::sycl::half getAbsoluteErrorMargin(const int32_t) { // Measured empirically with gemm. 
return 1.0f; } @@ -203,7 +188,7 @@ inline cl::sycl::half getAbsoluteErrorMargin(const bool) { */ template inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2, - const bool is_trsm = false) { + const int32_t margin_multiplier = 1) { // Shortcut, also handles case where both are zero if (scalar1 == scalar2) { return true; @@ -218,13 +203,14 @@ inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2, // Close to zero, the relative error doesn't work, use absolute error if (scalar1 == scalar_t{0} || scalar2 == scalar_t{0} || - absolute_diff < getAbsoluteErrorMargin(is_trsm)) { - return (absolute_diff < getAbsoluteErrorMargin(is_trsm)); + absolute_diff < getAbsoluteErrorMargin(margin_multiplier)) { + return (absolute_diff < + getAbsoluteErrorMargin(margin_multiplier)); } // Use relative error const auto absolute_sum = utils::abs(scalar1) + utils::abs(scalar2); return (absolute_diff / absolute_sum) < - getRelativeErrorMargin(is_trsm); + getRelativeErrorMargin(margin_multiplier); } /** @@ -239,7 +225,7 @@ inline bool compare_vectors(std::vector const& vec, std::vector const& ref, std::ostream& err_stream = std::cerr, std::string end_line = "\n", - const bool is_trsm = false) { + const int32_t margin_multiplier = 1) { if (vec.size() != ref.size()) { err_stream << "Error: tried to compare vectors of different sizes" << std::endl; @@ -247,7 +233,7 @@ inline bool compare_vectors(std::vector const& vec, } for (int i = 0; i < vec.size(); ++i) { - if (!almost_equal(vec[i], ref[i], is_trsm)) { + if (!almost_equal(vec[i], ref[i], margin_multiplier)) { err_stream << "Value mismatch at index " << i << ": " << vec[i] << "; expected " << ref[i] << end_line; return false; @@ -268,7 +254,7 @@ template inline bool compare_vectors(std::vector> const& vec, std::vector> const& ref, std::ostream& err_stream = std::cerr, - std::string end_line = "\n", bool is_trsm = false) { + std::string end_line = "\n") { if (vec.size() != ref.size()) { err_stream << 
"Error: tried to compare vectors of different sizes" << std::endl; diff --git a/test/unittest/blas3/blas3_trsm_test.cpp b/test/unittest/blas3/blas3_trsm_test.cpp index b0890d9e1..39e535eca 100644 --- a/test/unittest/blas3/blas3_trsm_test.cpp +++ b/test/unittest/blas3/blas3_trsm_test.cpp @@ -92,7 +92,12 @@ void run_test(const combination_t combi) { blas::helper::copy_to_host(q, b_gpu, B.data(), B.size()); sb_handle.wait(event); - bool isAlmostEqual = utils::compare_vectors(cpu_B, B, std::cerr, "", true); + bool isAlmostEqual = utils::compare_vectors( + cpu_B, B, std::cerr, "", + (en_joint_matrix != NULL) && (std::is_same::value) && + (*en_joint_matrix == '1') + ? 2 + : 1); ASSERT_TRUE(isAlmostEqual); From 23dff656df027edef2dbd2635fbd707fb2fe5b58 Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Tue, 6 Feb 2024 17:24:54 +0000 Subject: [PATCH 19/21] Added joint_matrix tests --- CMakeLists.txt | 1 + .../blas3/gemm_local_joint_matrix.hpp | 5 + test/unittest/CMakeLists.txt | 18 ++ test/unittest/joint_matrix/CMakeLists.txt | 74 +++++ .../joint_matrix/bfloat16_float_16_16_16.cpp | 99 +++++++ .../joint_matrix/bfloat16_float_32_8_16.cpp | 99 +++++++ .../joint_matrix/bfloat16_float_8_32_16.cpp | 99 +++++++ .../joint_matrix/cmake/FindPORTBLAS.cmake | 65 +++++ .../joint_matrix/half_float_16_16_16.cpp | 99 +++++++ .../joint_matrix/half_float_32_8_16.cpp | 99 +++++++ .../joint_matrix/half_float_8_32_16.cpp | 99 +++++++ .../joint_matrix/half_half_16_16_16.cpp | 99 +++++++ .../joint_matrix/half_half_32_8_16.cpp | 99 +++++++ .../joint_matrix/half_half_8_32_16.cpp | 99 +++++++ .../joint_matrix/joint_matrix_common.hpp | 260 ++++++++++++++++++ test/unittest/joint_matrix/launch_gemm.hpp | 247 +++++++++++++++++ .../joint_matrix/tf32_float_16_16_8.cpp | 99 +++++++ 17 files changed, 1660 insertions(+) create mode 100644 test/unittest/joint_matrix/CMakeLists.txt create mode 100644 test/unittest/joint_matrix/bfloat16_float_16_16_16.cpp create mode 100644 
test/unittest/joint_matrix/bfloat16_float_32_8_16.cpp create mode 100644 test/unittest/joint_matrix/bfloat16_float_8_32_16.cpp create mode 100644 test/unittest/joint_matrix/cmake/FindPORTBLAS.cmake create mode 100644 test/unittest/joint_matrix/half_float_16_16_16.cpp create mode 100644 test/unittest/joint_matrix/half_float_32_8_16.cpp create mode 100644 test/unittest/joint_matrix/half_float_8_32_16.cpp create mode 100644 test/unittest/joint_matrix/half_half_16_16_16.cpp create mode 100644 test/unittest/joint_matrix/half_half_32_8_16.cpp create mode 100644 test/unittest/joint_matrix/half_half_8_32_16.cpp create mode 100644 test/unittest/joint_matrix/joint_matrix_common.hpp create mode 100644 test/unittest/joint_matrix/launch_gemm.hpp create mode 100644 test/unittest/joint_matrix/tf32_float_16_16_8.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 21340430c..b3fd172d4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -203,6 +203,7 @@ export(EXPORT portblas option(BLAS_ENABLE_TESTING "Whether to enable testing" ON) option(ENABLE_EXPRESSION_TESTS "Whether to build expression tree fusion tests" OFF) +option(ENABLE_JOINTMATRIX_TESTS "Whether to build joint_matrix GEMM tests" OFF) if (INSTALL_HEADER_ONLY AND BLAS_ENABLE_TESTING) message(STATUS "Tests are disabled when installing portBLAS in header only mode") set(BLAS_ENABLE_TESTING OFF) diff --git a/src/operations/blas3/gemm_local_joint_matrix.hpp b/src/operations/blas3/gemm_local_joint_matrix.hpp index 3298203f2..440229fef 100644 --- a/src/operations/blas3/gemm_local_joint_matrix.hpp +++ b/src/operations/blas3/gemm_local_joint_matrix.hpp @@ -158,6 +158,11 @@ class Gemm 1 && jm_row_frags == num_sub_groups) || + (frags_per_sg == 1 && num_jm_frags == num_sub_groups), + "Joint Matrix Row Fragments needs to map 1:1 with total sub_groups."); + //! @brief leading dimension of block of A in local static constexpr index_t ldsa = (trans_a ? 
cl_elems : block_rows) + diff --git a/test/unittest/CMakeLists.txt b/test/unittest/CMakeLists.txt index 4cdcf7197..14e852af8 100644 --- a/test/unittest/CMakeLists.txt +++ b/test/unittest/CMakeLists.txt @@ -143,3 +143,21 @@ foreach(blas_test ${SYCL_UNITTEST_SRCS}) COMPONENT tests ) endforeach() + +if(${ENABLE_JOINTMATRIX_TESTS}) + if (${DPCPP_SYCL_TARGET} STREQUAL "nvptx64-nvidia-cuda") + string(FIND ${DPCPP_SYCL_ARCH} "_" start_idx) + if(start_idx) + MATH(EXPR start_idx "${start_idx} + 1") + string(SUBSTRING ${DPCPP_SYCL_ARCH} ${start_idx} "2" sm_val) + endif() + + if (${start_idx} AND ${sm_val} GREATER_EQUAL "80") + add_subdirectory(joint_matrix) + else() + message(FATAL_ERROR "Joint Matrix Tests only supported for NVIDIA GPUs with sm_80 arch and above.") + endif() + else() + message(FATAL_ERROR "Joint Matrix Tests only supported for NVIDIA GPUs.") + endif() +endif() diff --git a/test/unittest/joint_matrix/CMakeLists.txt b/test/unittest/joint_matrix/CMakeLists.txt new file mode 100644 index 000000000..efba7944c --- /dev/null +++ b/test/unittest/joint_matrix/CMakeLists.txt @@ -0,0 +1,74 @@ +#/*************************************************************************** +# * +# * @license +# * Copyright (C) Codeplay Software Limited +# * Licensed under the Apache License, Version 2.0 (the "License"); +# * you may not use this file except in compliance with the License. +# * You may obtain a copy of the License at +# * +# * http://www.apache.org/licenses/LICENSE-2.0 +# * +# * For your convenience, a copy of the License has been included in this +# * repository. +# * +# * Unless required by applicable law or agreed to in writing, software +# * distributed under the License is distributed on an "AS IS" BASIS, +# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# * See the License for the specific language governing permissions and +# * limitations under the License. 
+# * +# * portBLAS: BLAS implementation using SYCL +# * +# * @filename CMakeLists.txt +# * +# **************************************************************************/ + +set(PORTBLAS_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../include) +set(PORTBLAS_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../src) +list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) +list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake/Modules) +list(APPEND CMAKE_PREFIX_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../..) +include(ConfigurePORTBLAS) +include(SYCL) +find_package(PORTBLAS REQUIRED) + +set(PORTBLAS_JOINTMATRIX_TEST ${CMAKE_CURRENT_SOURCE_DIR}) + +include_directories(${PORTBLAS_TEST} ${BLAS_INCLUDE_DIRS}) + +# compiling tests +set(SYCL_UNITTEST_SRCS + ${PORTBLAS_JOINTMATRIX_TEST}/half_half_16_16_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/half_half_32_8_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/half_half_8_32_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/half_float_16_16_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/half_float_32_8_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/half_float_8_32_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/bfloat16_float_16_16_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/bfloat16_float_32_8_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/bfloat16_float_8_32_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/tf32_float_16_16_8.cpp +) + +foreach(blas_test ${SYCL_UNITTEST_SRCS}) + get_filename_component(test_exec ${blas_test} NAME_WE) + add_executable(joint_matrix_${test_exec}_test ../main.cpp ${blas_test}) + target_compile_definitions(joint_matrix_${test_exec}_test PRIVATE -DBLAS_INDEX_T=${BLAS_TEST_INDEX_TYPE}) + target_link_libraries(joint_matrix_${test_exec}_test PRIVATE gtest_main Clara::Clara blas::blas PORTBLAS::PORTBLAS) + target_include_directories(joint_matrix_${test_exec}_test PRIVATE ${SYCL_INCLUDE_DIRS}) + target_include_directories(joint_matrix_${test_exec}_test PRIVATE ${CBLAS_INCLUDE} ${PORTBLAS_COMMON_INCLUDE_DIR}) + target_compile_options(joint_matrix_${test_exec}_test PRIVATE 
${DPCPP_FLAGS}) + target_link_options(joint_matrix_${test_exec}_test PRIVATE ${DPCPP_FLAGS}) + + if(TEST_DEVICE) + add_test(NAME joint_matrix_${test_exec}_test COMMAND ${CMAKE_CURRENT_BINARY_DIR}/joint_matrix_${test_exec}_test --device ${TEST_DEVICE} --gtest_output=xml:output/) + else() + add_test(NAME joint_matrix_${test_exec}_test COMMAND ${CMAKE_CURRENT_BINARY_DIR}/joint_matrix_${test_exec}_test --gtest_output=xml:output/) + endif() + message(STATUS "Created google test joint_matrix_${test_exec}_test") + install(TARGETS joint_matrix_${test_exec}_test + RUNTIME + DESTINATION ${CMAKE_INSTALL_BINDIR} + COMPONENT tests + ) +endforeach() diff --git a/test/unittest/joint_matrix/bfloat16_float_16_16_16.cpp b/test/unittest/joint_matrix/bfloat16_float_16_16_16.cpp new file mode 100644 index 000000000..af6af66f0 --- /dev/null +++ b/test/unittest/joint_matrix/bfloat16_float_16_16_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + * portBLAS: BLAS implementation using SYCL + * + * @filename bfloat16_float_16_16_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesBfloat16Floatm16n16k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesBfloat16Floatm16n16k16); + +template +const auto MediumMatricesBfloat16Floatm16n16k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511, 768), // m + ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + 
::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesBfloat16Floatm16n16k16); + +template +const auto LargeMatricesBfloat16Floatm16n16k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesBfloat16Floatm16n16k16); diff --git a/test/unittest/joint_matrix/bfloat16_float_32_8_16.cpp b/test/unittest/joint_matrix/bfloat16_float_32_8_16.cpp new file mode 100644 index 000000000..d53941431 --- /dev/null +++ b/test/unittest/joint_matrix/bfloat16_float_32_8_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename bfloat16_float_32_8_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesBfloat16Floatm32n8k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesBfloat16Floatm32n8k16); + +template +const auto MediumMatricesBfloat16Floatm32n8k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511, 768), // m + ::testing::Values(65, 127, 234, 511, 768), // n + 
::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesBfloat16Floatm32n8k16); + +template +const auto LargeMatricesBfloat16Floatm32n8k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesBfloat16Floatm32n8k16); diff --git a/test/unittest/joint_matrix/bfloat16_float_8_32_16.cpp b/test/unittest/joint_matrix/bfloat16_float_8_32_16.cpp new file mode 100644 index 000000000..8e086f858 --- /dev/null +++ b/test/unittest/joint_matrix/bfloat16_float_8_32_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename bfloat16_float_8_32_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesBfloat16Floatm8n32k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesBfloat16Floatm8n32k16); + +template +const auto MediumMatricesBfloat16Floatm8n32k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation 
type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511, 768), // m + ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesBfloat16Floatm8n32k16); + +template +const auto LargeMatricesBfloat16Floatm8n32k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesBfloat16Floatm8n32k16); diff --git a/test/unittest/joint_matrix/cmake/FindPORTBLAS.cmake b/test/unittest/joint_matrix/cmake/FindPORTBLAS.cmake new file mode 100644 index 000000000..f102b0247 --- /dev/null +++ b/test/unittest/joint_matrix/cmake/FindPORTBLAS.cmake @@ -0,0 +1,65 @@ +#/*************************************************************************** +# * +# * @license +# * Copyright (C) Codeplay Software Limited +# * Licensed under the Apache License, Version 2.0 (the 
"License"); +# * you may not use this file except in compliance with the License. +# * You may obtain a copy of the License at +# * +# * http://www.apache.org/licenses/LICENSE-2.0 +# * +# * For your convenience, a copy of the License has been included in this +# * repository. +# * +# * Unless required by applicable law or agreed to in writing, software +# * distributed under the License is distributed on an "AS IS" BASIS, +# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# * See the License for the specific language governing permissions and +# * limitations under the License. +# * +# * portBLAS: BLAS implementation using SYCL +# * +# * @filename FindPORTBLAS.cmake +# * +# **************************************************************************/ + +find_path(PORTBLAS_INCLUDE_DIR + NAMES portblas.h + PATH_SUFFIXES include + HINTS ${PORTBLAS_DIR} + DOC "The PORTBLAS include directory" +) + +find_path(PORTBLAS_SRC_DIR + NAMES portblas.hpp + PATH_SUFFIXES src + HINTS ${PORTBLAS_DIR} + DOC "The PORTBLAS source directory" +) + + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(PORTBLAS + FOUND_VAR PORTBLAS_FOUND + REQUIRED_VARS PORTBLAS_INCLUDE_DIR + PORTBLAS_SRC_DIR +) + +mark_as_advanced(PORTBLAS_FOUND + PORTBLAS_SRC_DIR + PORTBLAS_INCLUDE_DIR +) + +if(PORTBLAS_FOUND) + set(PORTBLAS_INCLUDE_DIRS + ${PORTBLAS_INCLUDE_DIR} + ${PORTBLAS_SRC_DIR} + ) +endif() + +if(PORTBLAS_FOUND AND NOT TARGET PORTBLAS::PORTBLAS) + add_library(PORTBLAS::PORTBLAS INTERFACE IMPORTED) + set_target_properties(PORTBLAS::PORTBLAS PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${PORTBLAS_INCLUDE_DIRS}" + ) +endif() diff --git a/test/unittest/joint_matrix/half_float_16_16_16.cpp b/test/unittest/joint_matrix/half_float_16_16_16.cpp new file mode 100644 index 000000000..d45fcfbaf --- /dev/null +++ b/test/unittest/joint_matrix/half_float_16_16_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + 
 * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename half_float_16_16_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesHalfFloat161616 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesHalfFloat161616); + +template +const auto MediumMatricesHalfFloat161616 = ::testing::Combine( + ::testing::Values("half"), // input type + 
 ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511, 768), // m + ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesHalfFloat161616); + +template +const auto LargeMatricesHalfFloat161616 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesHalfFloat161616); diff --git a/test/unittest/joint_matrix/half_float_32_8_16.cpp b/test/unittest/joint_matrix/half_float_32_8_16.cpp new file mode 100644 index 000000000..db39802c5 --- /dev/null +++ b/test/unittest/joint_matrix/half_float_32_8_16.cpp @@ -0,0 +1,99 @@ 
+/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename half_float_32_8_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesHalfFloatm32n8k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesHalfFloatm32n8k16); + +template +const auto MediumMatricesHalfFloatm32n8k16 = 
 ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511, 768), // m + ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesHalfFloatm32n8k16); + +template +const auto LargeMatricesHalfFloatm32n8k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesHalfFloatm32n8k16); diff --git a/test/unittest/joint_matrix/half_float_8_32_16.cpp b/test/unittest/joint_matrix/half_float_8_32_16.cpp new file mode 100644 index 000000000..61048be18 --- /dev/null +++ 
b/test/unittest/joint_matrix/half_float_8_32_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename half_float_8_32_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesHalfFloatm8n32k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, 
 SmallMatricesHalfFloatm8n32k16); + +template +const auto MediumMatricesHalfFloatm8n32k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511, 768), // m + ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesHalfFloatm8n32k16); + +template +const auto LargeMatricesHalfFloatm8n32k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesHalfFloatm8n32k16); diff --git a/test/unittest/joint_matrix/half_half_16_16_16.cpp b/test/unittest/joint_matrix/half_half_16_16_16.cpp new file mode 
100644 index 000000000..91de6a527 --- /dev/null +++ b/test/unittest/joint_matrix/half_half_16_16_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename half_half_16_16_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesHalfHalfm16n16k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); 
+GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesHalfHalfm16n16k16); + +template +const auto MediumMatricesHalfHalfm16n16k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511, 768), // m + ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesHalfHalfm16n16k16); + +template +const auto LargeMatricesHalfHalfm16n16k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesHalfHalfm16n16k16); diff --git a/test/unittest/joint_matrix/half_half_32_8_16.cpp 
b/test/unittest/joint_matrix/half_half_32_8_16.cpp new file mode 100644 index 000000000..ad719d8e9 --- /dev/null +++ b/test/unittest/joint_matrix/half_half_32_8_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename half_half_32_8_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesHalfHalfm32n8k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + 
::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesHalfHalfm32n8k16); + +template +const auto MediumMatricesHalfHalfm32n8k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511, 768), // m + ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesHalfHalfm32n8k16); + +template +const auto LargeMatricesHalfHalfm32n8k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesHalfHalfm32n8k16); diff --git 
a/test/unittest/joint_matrix/half_half_8_32_16.cpp b/test/unittest/joint_matrix/half_half_8_32_16.cpp new file mode 100644 index 000000000..090eab1ee --- /dev/null +++ b/test/unittest/joint_matrix/half_half_8_32_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + * portBLAS: BLAS implementation using SYCL + * + * @filename half_half_8_32_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesHalfHalfm8n32k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesHalfHalfm8n32k16); + +template +const auto MediumMatricesHalfHalfm8n32k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511, 768), // m + ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); 
+GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesHalfHalfm8n32k16); + +template +const auto LargeMatricesHalfHalfm8n32k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesHalfHalfm8n32k16); diff --git a/test/unittest/joint_matrix/joint_matrix_common.hpp b/test/unittest/joint_matrix/joint_matrix_common.hpp new file mode 100644 index 000000000..7998f9dc0 --- /dev/null +++ b/test/unittest/joint_matrix/joint_matrix_common.hpp @@ -0,0 +1,260 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename joint_matrix_common.hpp + * + **************************************************************************/ + +#include "launch_gemm.hpp" + +template +using joint_matrix_arguments_t = + std::tuple; + +template +inline void verify_gemm(const joint_matrix_arguments_t arguments) { + std::string jm_inType, jm_outType; + index_t jm_m, jm_n, jm_k; + std::string alloc; + index_t offset; + index_t batch; + index_t m; + index_t n; + index_t k; + char transa; + char transb; + scalar_t alpha; + scalar_t beta; + index_t lda_mul; + index_t ldb_mul; + index_t ldc_mul; + gemm_batch_type_t batch_type; + std::tie(jm_inType, jm_outType, jm_m, jm_n, jm_k, alloc, offset, batch, m, n, + k, transa, transb, alpha, beta, lda_mul, ldb_mul, ldc_mul, + batch_type) = arguments; + + assert(batch_type == gemm_batch_type_t::strided); + + const char ta_str[2] = {transa, '\0'}; + const char tb_str[2] = {transb, '\0'}; + + auto q = make_queue(); + blas::SB_Handle sb_handle(q); + + const index_t lda = ((transa != 'n') ? k : m) * lda_mul; + const index_t ldb = ((transb != 'n') ? 
n : k) * ldb_mul; + const index_t ldc = m * ldc_mul; + + const index_t size_a = m * k * lda_mul; + const index_t size_b = k * n * ldb_mul; + const index_t size_c = m * n * ldc_mul; + + const index_t buffer_size_a = batch * size_a + offset; + const index_t buffer_size_b = batch * size_b + offset; + const index_t buffer_size_c = batch * size_c + offset; + + std::vector a_m(buffer_size_a); + std::vector b_m(buffer_size_b); + std::vector c_m_gpu(buffer_size_c); + + fill_random(a_m); + fill_random(b_m); + fill_random(c_m_gpu); + + index_t nbits = 13; + if (jm_inType == "bfloat16") { + nbits = 16; + } + set_to_zero_last_nbits(a_m, nbits); + set_to_zero_last_nbits(b_m, nbits); + set_to_zero_last_nbits(c_m_gpu, nbits); + set_to_zero_last_nbits(alpha, nbits); + set_to_zero_last_nbits(beta, nbits); + + std::vector c_m_cpu = c_m_gpu; + + // Use system blas to create a reference output + for (int i = 0; i < batch; ++i) { + reference_blas::gemm(ta_str, tb_str, m, n, k, alpha, + a_m.data() + i * size_a + offset, lda, + b_m.data() + i * size_b + offset, ldb, beta, + c_m_cpu.data() + i * size_c + offset, ldc); + } + + auto m_a_gpu = blas::helper::allocate(buffer_size_a, q); + auto m_b_gpu = blas::helper::allocate(buffer_size_b, q); + auto m_c_gpu = blas::helper::allocate(buffer_size_c, q); + + auto copy_a = + blas::helper::copy_to_device(q, a_m.data(), m_a_gpu, buffer_size_a); + auto copy_b = + blas::helper::copy_to_device(q, b_m.data(), m_b_gpu, buffer_size_b); + auto copy_c = + blas::helper::copy_to_device(q, c_m_gpu.data(), m_c_gpu, buffer_size_c); + + // portBLAS GEMM implementation + typename blas::SB_Handle::event_t gemm_event; + if (jm_inType == "half" && jm_outType == "float") { + if (jm_m == 16 && jm_n == 16) { + gemm_event = launch_gemm_with_beta<16, 16, 16, cl::sycl::half, float>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, ldc, + size_c, batch, batch_type, {copy_a, copy_b, 
copy_c}); + } else if (jm_m == 32 && jm_n == 8) { + gemm_event = launch_gemm_with_beta<32, 8, 16, cl::sycl::half, float>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, ldc, + size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } else if (jm_n == 32 && jm_m == 8) { + gemm_event = launch_gemm_with_beta<8, 32, 16, cl::sycl::half, float>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, ldc, + size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } + } else if (jm_inType == "half" && jm_outType == "half") { + if (jm_m == 16 && jm_n == 16) { + gemm_event = + launch_gemm_with_beta<16, 16, 16, cl::sycl::half, cl::sycl::half>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, + ldc, size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } else if (jm_m == 32 && jm_n == 8) { + gemm_event = + launch_gemm_with_beta<32, 8, 16, cl::sycl::half, cl::sycl::half>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, + ldc, size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } else if (jm_n == 32 && jm_m == 8) { + gemm_event = + launch_gemm_with_beta<8, 32, 16, cl::sycl::half, cl::sycl::half>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, + ldc, size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } + } else if (jm_inType == "bfloat16" && jm_outType == "float") { + if (jm_m == 16 && jm_n == 16) { + gemm_event = + launch_gemm_with_beta<16, 16, 16, sycl::ext::oneapi::bfloat16, float>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, + ldc, size_c, batch, batch_type, {copy_a, 
copy_b, copy_c}); + } else if (jm_m == 32 && jm_n == 8) { + gemm_event = + launch_gemm_with_beta<32, 8, 16, sycl::ext::oneapi::bfloat16, float>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, + ldc, size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } else if (jm_n == 32 && jm_m == 8) { + gemm_event = + launch_gemm_with_beta<8, 32, 16, sycl::ext::oneapi::bfloat16, float>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, + ldc, size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } + } else if (jm_inType == "tf32" && jm_outType == "float") { + using namespace sycl::ext::oneapi::experimental::matrix; + gemm_event = launch_gemm_with_beta<16, 16, 8, precision::tf32, float>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, ldc, + size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } + sb_handle.wait(gemm_event); + + auto event = + blas::helper::copy_to_host(q, m_c_gpu, c_m_gpu.data(), buffer_size_c); + sb_handle.wait(event); + + const bool isAlmostEqual = utils::compare_vectors( + c_m_gpu, c_m_cpu, std::cerr, "", + jm_inType == "half" && jm_outType == "half" ? 
3 : 1); + ASSERT_TRUE(isAlmostEqual); + + helper::deallocate(m_a_gpu, q); + helper::deallocate(m_b_gpu, q); + helper::deallocate(m_c_gpu, q); +} + +template +inline void verify_gemm(const joint_matrix_arguments_t arguments) { + std::string jm_inType, jm_OutType; + index_t jm_m, jm_n, jm_k; + std::string alloc; + index_t offset; + index_t batch; + index_t m; + index_t n; + index_t k; + char transa; + char transb; + scalar_t alpha; + scalar_t beta; + index_t lda_mul; + index_t ldb_mul; + index_t ldc_mul; + gemm_batch_type_t batch_type; + std::tie(jm_inType, jm_OutType, jm_m, jm_n, jm_k, alloc, offset, batch, m, n, + k, transa, transb, alpha, beta, lda_mul, ldb_mul, ldc_mul, + batch_type) = arguments; + + if (alloc == "usm") { +#ifdef SB_ENABLE_USM + verify_gemm(arguments); +#else + GTEST_SKIP(); +#endif + } else { + verify_gemm(arguments); + } +} + +template <> +inline void dump_arg(std::ostream& ss, + gemm_batch_type_t batch_type) { + ss << (int)batch_type; +} + +template +static std::string generate_name( + const ::testing::TestParamInfo>& info) { + std::string jm_inType, jm_OutType; + int jm_m, jm_n, jm_k; + std::string alloc; + int offset, batch, m, n, k, ldaMul, ldbMul, ldcMul; + char transa, transb; + T alpha, beta; + gemm_batch_type_t batchType; + BLAS_GENERATE_NAME(info.param, jm_inType, jm_OutType, jm_m, jm_n, jm_k, alloc, + offset, batch, m, n, k, transa, transb, alpha, beta, + ldaMul, ldbMul, ldcMul, batchType); +} + +/** Registers Joint Matrix test for all supported data types (only float for + * now) + * @param test_suite Name of the test suite + * @param combination Combinations object + * @see BLAS_REGISTER_TEST_CUSTOM_NAME + */ +#define GENERATE_JOINTMATRIX_TEST(test_suite, combination) \ + BLAS_REGISTER_TEST_FLOAT_CUSTOM_NAME(test_suite, test_suite##combination, \ + verify_gemm, joint_matrix_arguments_t, \ + combination, generate_name); diff --git a/test/unittest/joint_matrix/launch_gemm.hpp b/test/unittest/joint_matrix/launch_gemm.hpp new file mode 
100644 index 000000000..afab6a0b3 --- /dev/null +++ b/test/unittest/joint_matrix/launch_gemm.hpp @@ -0,0 +1,247 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename launch_gemm.hpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "interface/gemm_launcher.hpp" +#include "portblas.hpp" +#include + +template +PORTBLAS_ALWAYS_INLINE + typename std::enable_if::type + launch_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, + index_t _stridea, container_1_t _b, index_t _ldb, + index_t _strideb, element_t _beta, container_2_t _c, + index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { + if (_M > 1024 && _N > 1024) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, true, true, + 128, + Tile<8, 8, 16, 16, 16, 2, 1, 1, 1, 1, jm_M, jm_N, jm_K, inp_jmT, + out_jmT>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 
1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } else if (_M > 64 && _N > 64) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 128, false, true, true, + 128, + Tile<4, 8, 16, 8, 16, 2, 1, 1, 1, 1, jm_M, jm_N, jm_K, inp_jmT, + out_jmT>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 128, false, true, true, + 128, + Tile<2, 4, 16, 8, 16, 2, 1, 1, 1, 1, jm_M, jm_N, jm_K, inp_jmT, + out_jmT>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } +} + +template +PORTBLAS_ALWAYS_INLINE + typename std::enable_if::type + launch_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, + index_t _stridea, container_1_t _b, index_t _ldb, + index_t _strideb, element_t _beta, container_2_t _c, + index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { + if (_M > 1024 && _N > 1024) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 128, false, true, true, + 128, + Tile<8, 16, 16, 8, 16, 2, 1, 1, 1, 1, jm_M, 
jm_N, jm_K, inp_jmT, + out_jmT>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } else if (_M > 64 && _N > 64) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, false, true, true, 128, + Tile<8, 8, 8, 8, 16, 2, 1, 1, 1, 1, jm_M, jm_N, jm_K, inp_jmT, out_jmT>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 128, false, true, true, + 128, + Tile<2, 4, 16, 8, 16, 2, 1, 1, 1, 1, jm_M, jm_N, jm_K, inp_jmT, + out_jmT>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } +} + +template +PORTBLAS_ALWAYS_INLINE + typename std::enable_if::type + launch_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, + index_t _stridea, container_1_t _b, index_t _ldb, + index_t _strideb, element_t _beta, container_2_t _c, + index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { + if 
(_M > 1024 && _N > 1024) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, true, true, + 128, + Tile<4, 4, 16, 16, 16, 2, 1, 1, 1, 1, jm_M, jm_N, jm_K, inp_jmT, + out_jmT>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 128, false, true, true, + 128, + Tile<2, 4, 16, 8, 16, 2, 1, 1, 1, 1, jm_M, jm_N, jm_K, inp_jmT, + out_jmT>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } +} + +template +PORTBLAS_ALWAYS_INLINE typename sb_handle_t::event_t launch_gemm_with_transpose( + sb_handle_t& sb_handle, char _trans_a, char _trans_b, index_t _M, + index_t _N, index_t _K, element_t _alpha, container_0_t _a, index_t _lda, + index_t _stridea, container_1_t _b, index_t _ldb, index_t _strideb, + element_t _beta, container_2_t _c, index_t _ldc, index_t _stridec, + index_t batch_size, gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { + typename sb_handle_t::event_t gemm_event; + if (_trans_a == 't' && _trans_b == 't') { + gemm_event = launch_gemm( + sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, + _beta, _c, _ldc, _stridec, batch_size, batch_type, _dependencies); + } else if (_trans_a == 'n' && _trans_b == 'n') { + gemm_event = launch_gemm( + sb_handle, _M, _N, _K, 
_alpha, _a, _lda, _stridea, _b, _ldb, _strideb, + _beta, _c, _ldc, _stridec, batch_size, batch_type, _dependencies); + } else if (_trans_a == 't' && _trans_b == 'n') { + gemm_event = launch_gemm( + sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, + _beta, _c, _ldc, _stridec, batch_size, batch_type, _dependencies); + } else if (_trans_a == 'n' && _trans_b == 't') { + gemm_event = launch_gemm( + sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, + _beta, _c, _ldc, _stridec, batch_size, batch_type, _dependencies); + } + return gemm_event; +} + +template +PORTBLAS_ALWAYS_INLINE typename sb_handle_t::event_t launch_gemm_with_beta( + sb_handle_t& sb_handle, char _trans_a, char _trans_b, index_t _M, + index_t _N, index_t _K, element_t _alpha, container_0_t _a, index_t _lda, + index_t _stridea, container_1_t _b, index_t _ldb, index_t _strideb, + element_t _beta, container_2_t _c, index_t _ldc, index_t _stridec, + index_t batch_size, gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { + typename sb_handle_t::event_t gemm_event; + if (_beta == (element_t)0) { + gemm_event = + launch_gemm_with_transpose( + sb_handle, _trans_a, _trans_b, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, + batch_type, _dependencies); + } else { + gemm_event = + launch_gemm_with_transpose( + sb_handle, _trans_a, _trans_b, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, + batch_type, _dependencies); + } + return gemm_event; +} diff --git a/test/unittest/joint_matrix/tf32_float_16_16_8.cpp b/test/unittest/joint_matrix/tf32_float_16_16_8.cpp new file mode 100644 index 000000000..249c88c34 --- /dev/null +++ b/test/unittest/joint_matrix/tf32_float_16_16_8.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * 
Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename tf32_float_16_16_8.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesTF32Floatm16n16k8 = ::testing::Combine( + ::testing::Values("tf32"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(8), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesTF32Floatm16n16k8); + +template +const auto MediumMatricesTF32Floatm16n16k8 = ::testing::Combine( + ::testing::Values("tf32"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + 
::testing::Values(16), // jm_n + ::testing::Values(8), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511, 768), // m + ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesTF32Floatm16n16k8); + +template +const auto LargeMatricesTF32Floatm16n16k8 = ::testing::Combine( + ::testing::Values("tf32"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(8), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesTF32Floatm16n16k8); From fe55302c1fcd58926bed9f4e2ab7858d09e63b13 Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Thu, 22 Feb 2024 10:13:54 +0000 Subject: [PATCH 20/21] Address test commit feedback * Fix half tests by changing the initialization values * Reduced the no. 
of tests executed --- README.md | 1 + test/blas_test.hpp | 20 +++++++++++-------- test/unittest/blas3/blas3_trsm_test.cpp | 4 ++-- .../joint_matrix/bfloat16_float_16_16_16.cpp | 10 +++++----- .../joint_matrix/bfloat16_float_32_8_16.cpp | 10 +++++----- .../joint_matrix/bfloat16_float_8_32_16.cpp | 10 +++++----- .../joint_matrix/half_float_16_16_16.cpp | 10 +++++----- .../joint_matrix/half_float_32_8_16.cpp | 10 +++++----- .../joint_matrix/half_float_8_32_16.cpp | 10 +++++----- .../joint_matrix/half_half_16_16_16.cpp | 10 +++++----- .../joint_matrix/half_half_32_8_16.cpp | 10 +++++----- .../joint_matrix/half_half_8_32_16.cpp | 10 +++++----- .../joint_matrix/joint_matrix_common.hpp | 19 ++++++++++++------ .../joint_matrix/tf32_float_16_16_8.cpp | 10 +++++----- 14 files changed, 78 insertions(+), 66 deletions(-) diff --git a/README.md b/README.md index bba1a6bc0..2d5e93ae8 100644 --- a/README.md +++ b/README.md @@ -458,6 +458,7 @@ Some of the supported options are: | `CMAKE_INSTALL_PREFIX` | path | Specify the install location, used when invoking `ninja install` | | `BUILD_SHARED_LIBS` | `ON`/`OFF` | Build as shared library (`ON` by default) | | `ENABLE_EXPRESSION_TESTS` | `ON`/`OFF` | Build additional tests that use the header-only framework (e.g to test expression trees); `OFF` by default | +| `ENABLE_JOINTMATRIX_TESTS` | `ON`/`OFF` | Build additional tests that use joint_matrix extension; `OFF` by default | | `BLAS_VERIFY_BENCHMARK` | `ON`/`OFF` | Verify the results of the benchmarks instead of only measuring the performance. See the documentation of the benchmarks for more details. `ON` by default | | `BLAS_MEMPOOL_BENCHMARK` | `ON`/`OFF` | Determines whether to enable the scratchpad memory pool for benchmark execution. 
`OFF` by default | | `BLAS_ENABLE_CONST_INPUT` | `ON`/`OFF` | Determines whether to enable kernel instantiation with const input buffer (`ON` by default) | diff --git a/test/blas_test.hpp b/test/blas_test.hpp index 94c3689bf..0e7e70e1c 100644 --- a/test/blas_test.hpp +++ b/test/blas_test.hpp @@ -225,27 +225,31 @@ static inline void fill_trsm_matrix(std::vector &A, size_t k, /** * @brief Set to zero the last n bits of a float. - * + * @tparam T value type. * @param val input/output float value. * @param nbits number of last bit set to zero. It is set by default to 13 since * this is the difference of the number of bits of the mantissa between floats * (23) and FP16 / NVIDIA TF32 (10). For bfloat16, this value needs to be set to * 16 to get correct result. */ -static inline void set_to_zero_last_nbits(float &val, int32_t nbits = 13) { - int32_t *int_pntr = reinterpret_cast(&val); - *int_pntr = (*int_pntr >> nbits) << nbits; +template +void set_to_zero_last_nbits(T &val, int32_t nbits = 13) { + static_assert(sizeof(T) <= 64); + using integer_t = + std::conditional_t; + integer_t *int_pntr = reinterpret_cast(&val); } /** * @brief Set to zero the last n bits of floats contained in a vector. - * + * @tparam T value type. * @param val input/output float vector. * @param nbits number of last bit set to zero. 
*/ -static inline void set_to_zero_last_nbits(std::vector &vec, - int32_t nbits = 13) { - for (float &val : vec) set_to_zero_last_nbits(val, nbits); +template +void set_to_zero_last_nbits(std::vector &vec, int32_t nbits = 13) { + for (T &val : vec) set_to_zero_last_nbits(val, nbits); } /** diff --git a/test/unittest/blas3/blas3_trsm_test.cpp b/test/unittest/blas3/blas3_trsm_test.cpp index 39e535eca..6f44dab93 100644 --- a/test/unittest/blas3/blas3_trsm_test.cpp +++ b/test/unittest/blas3/blas3_trsm_test.cpp @@ -93,10 +93,10 @@ void run_test(const combination_t combi) { sb_handle.wait(event); bool isAlmostEqual = utils::compare_vectors( - cpu_B, B, std::cerr, "", + cpu_B, B, std::cerr, "\n", (en_joint_matrix != NULL) && (std::is_same::value) && (*en_joint_matrix == '1') - ? 2 + ? 3 : 1); ASSERT_TRUE(isAlmostEqual); diff --git a/test/unittest/joint_matrix/bfloat16_float_16_16_16.cpp b/test/unittest/joint_matrix/bfloat16_float_16_16_16.cpp index af6af66f0..ba4d35f7b 100644 --- a/test/unittest/joint_matrix/bfloat16_float_16_16_16.cpp +++ b/test/unittest/joint_matrix/bfloat16_float_16_16_16.cpp @@ -34,7 +34,7 @@ const auto SmallMatricesBfloat16Floatm16n16k16 = ::testing::Combine( ::testing::Values(16), // jm_n ::testing::Values(16), // jm_k ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch ::testing::Values(11, 16, 32, 63), // m ::testing::Values(11, 16, 32, 63), // n @@ -58,10 +58,10 @@ const auto MediumMatricesBfloat16Floatm16n16k16 = ::testing::Combine( ::testing::Values(16), // jm_n ::testing::Values(16), // jm_k ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch - ::testing::Values(65, 127, 234, 511, 768), // m - ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127, 234, 511), // m + ::testing::Values(65, 127, 234, 511), // n 
::testing::Values(65, 127), // k ::testing::Values('n', 't'), // transa ::testing::Values('n', 't'), // transb @@ -82,7 +82,7 @@ const auto LargeMatricesBfloat16Floatm16n16k16 = ::testing::Combine( ::testing::Values(16), // jm_n ::testing::Values(16), // jm_k ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch ::testing::Values(1024, 1535, 2024), // m ::testing::Values(1024, 1535, 2024), // n diff --git a/test/unittest/joint_matrix/bfloat16_float_32_8_16.cpp b/test/unittest/joint_matrix/bfloat16_float_32_8_16.cpp index d53941431..49cc187ff 100644 --- a/test/unittest/joint_matrix/bfloat16_float_32_8_16.cpp +++ b/test/unittest/joint_matrix/bfloat16_float_32_8_16.cpp @@ -34,7 +34,7 @@ const auto SmallMatricesBfloat16Floatm32n8k16 = ::testing::Combine( ::testing::Values(8), // jm_n ::testing::Values(16), // jm_k ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch ::testing::Values(11, 16, 32, 63), // m ::testing::Values(11, 16, 32, 63), // n @@ -58,10 +58,10 @@ const auto MediumMatricesBfloat16Floatm32n8k16 = ::testing::Combine( ::testing::Values(8), // jm_n ::testing::Values(16), // jm_k ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch - ::testing::Values(65, 127, 234, 511, 768), // m - ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127, 234, 511), // m + ::testing::Values(65, 127, 234, 511), // n ::testing::Values(65, 127), // k ::testing::Values('n', 't'), // transa ::testing::Values('n', 't'), // transb @@ -82,7 +82,7 @@ const auto LargeMatricesBfloat16Floatm32n8k16 = ::testing::Combine( ::testing::Values(8), // jm_n ::testing::Values(16), // jm_k ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // 
offset + ::testing::Values(33), // offset ::testing::Values(1), // batch ::testing::Values(1024, 1535, 2024), // m ::testing::Values(1024, 1535, 2024), // n diff --git a/test/unittest/joint_matrix/bfloat16_float_8_32_16.cpp b/test/unittest/joint_matrix/bfloat16_float_8_32_16.cpp index 8e086f858..6583fd450 100644 --- a/test/unittest/joint_matrix/bfloat16_float_8_32_16.cpp +++ b/test/unittest/joint_matrix/bfloat16_float_8_32_16.cpp @@ -34,7 +34,7 @@ const auto SmallMatricesBfloat16Floatm8n32k16 = ::testing::Combine( ::testing::Values(32), // jm_n ::testing::Values(16), // jm_k ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch ::testing::Values(11, 16, 32, 63), // m ::testing::Values(11, 16, 32, 63), // n @@ -58,10 +58,10 @@ const auto MediumMatricesBfloat16Floatm8n32k16 = ::testing::Combine( ::testing::Values(32), // jm_n ::testing::Values(16), // jm_k ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch - ::testing::Values(65, 127, 234, 511, 768), // m - ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127, 234, 511), // m + ::testing::Values(65, 127, 234, 511), // n ::testing::Values(65, 127), // k ::testing::Values('n', 't'), // transa ::testing::Values('n', 't'), // transb @@ -82,7 +82,7 @@ const auto LargeMatricesBfloat16Floatm8n32k16 = ::testing::Combine( ::testing::Values(32), // jm_n ::testing::Values(16), // jm_k ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch ::testing::Values(1024, 1535, 2024), // m ::testing::Values(1024, 1535, 2024), // n diff --git a/test/unittest/joint_matrix/half_float_16_16_16.cpp b/test/unittest/joint_matrix/half_float_16_16_16.cpp index d45fcfbaf..88b3bac8b 100644 --- 
a/test/unittest/joint_matrix/half_float_16_16_16.cpp +++ b/test/unittest/joint_matrix/half_float_16_16_16.cpp @@ -34,7 +34,7 @@ const auto SmallMatricesHalfFloat161616 = ::testing::Combine( ::testing::Values(16), // jm_n ::testing::Values(16), // jm_n ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch ::testing::Values(11, 16, 32, 63), // m ::testing::Values(11, 16, 32, 63), // n @@ -58,10 +58,10 @@ const auto MediumMatricesHalfFloat161616 = ::testing::Combine( ::testing::Values(16), // jm_n ::testing::Values(16), // jm_n ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch - ::testing::Values(65, 127, 234, 511, 768), // m - ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127, 234, 511), // m + ::testing::Values(65, 127, 234, 511), // n ::testing::Values(65, 127), // k ::testing::Values('n', 't'), // transa ::testing::Values('n', 't'), // transb @@ -82,7 +82,7 @@ const auto LargeMatricesHalfFloat161616 = ::testing::Combine( ::testing::Values(16), // jm_n ::testing::Values(16), // jm_n ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch ::testing::Values(1024, 1535, 2024), // m ::testing::Values(1024, 1535, 2024), // n diff --git a/test/unittest/joint_matrix/half_float_32_8_16.cpp b/test/unittest/joint_matrix/half_float_32_8_16.cpp index db39802c5..b370e67f2 100644 --- a/test/unittest/joint_matrix/half_float_32_8_16.cpp +++ b/test/unittest/joint_matrix/half_float_32_8_16.cpp @@ -34,7 +34,7 @@ const auto SmallMatricesHalfFloatm32n8k16 = ::testing::Combine( ::testing::Values(8), // jm_n ::testing::Values(16), // jm_n ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // 
offset ::testing::Values(1), // batch ::testing::Values(11, 16, 32, 63), // m ::testing::Values(11, 16, 32, 63), // n @@ -58,10 +58,10 @@ const auto MediumMatricesHalfFloatm32n8k16 = ::testing::Combine( ::testing::Values(8), // jm_n ::testing::Values(16), // jm_n ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch - ::testing::Values(65, 127, 234, 511, 768), // m - ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127, 234, 511), // m + ::testing::Values(65, 127, 234, 511), // n ::testing::Values(65, 127), // k ::testing::Values('n', 't'), // transa ::testing::Values('n', 't'), // transb @@ -82,7 +82,7 @@ const auto LargeMatricesHalfFloatm32n8k16 = ::testing::Combine( ::testing::Values(8), // jm_n ::testing::Values(16), // jm_n ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch ::testing::Values(1024, 1535, 2024), // m ::testing::Values(1024, 1535, 2024), // n diff --git a/test/unittest/joint_matrix/half_float_8_32_16.cpp b/test/unittest/joint_matrix/half_float_8_32_16.cpp index 61048be18..9853332ae 100644 --- a/test/unittest/joint_matrix/half_float_8_32_16.cpp +++ b/test/unittest/joint_matrix/half_float_8_32_16.cpp @@ -34,7 +34,7 @@ const auto SmallMatricesHalfFloatm8n32k16 = ::testing::Combine( ::testing::Values(32), // jm_n ::testing::Values(16), // jm_n ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch ::testing::Values(11, 16, 32, 63), // m ::testing::Values(11, 16, 32, 63), // n @@ -58,10 +58,10 @@ const auto MediumMatricesHalfFloatm8n32k16 = ::testing::Combine( ::testing::Values(32), // jm_n ::testing::Values(16), // jm_n ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + 
::testing::Values(33), // offset ::testing::Values(1), // batch - ::testing::Values(65, 127, 234, 511, 768), // m - ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127, 234, 511), // m + ::testing::Values(65, 127, 234, 511), // n ::testing::Values(65, 127), // k ::testing::Values('n', 't'), // transa ::testing::Values('n', 't'), // transb @@ -82,7 +82,7 @@ const auto LargeMatricesHalfFloatm8n32k16 = ::testing::Combine( ::testing::Values(32), // jm_n ::testing::Values(16), // jm_n ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch ::testing::Values(1024, 1535, 2024), // m ::testing::Values(1024, 1535, 2024), // n diff --git a/test/unittest/joint_matrix/half_half_16_16_16.cpp b/test/unittest/joint_matrix/half_half_16_16_16.cpp index 91de6a527..58241fca0 100644 --- a/test/unittest/joint_matrix/half_half_16_16_16.cpp +++ b/test/unittest/joint_matrix/half_half_16_16_16.cpp @@ -34,7 +34,7 @@ const auto SmallMatricesHalfHalfm16n16k16 = ::testing::Combine( ::testing::Values(16), // jm_n ::testing::Values(16), // jm_n ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch ::testing::Values(11, 16, 32, 63), // m ::testing::Values(11, 16, 32, 63), // n @@ -58,10 +58,10 @@ const auto MediumMatricesHalfHalfm16n16k16 = ::testing::Combine( ::testing::Values(16), // jm_n ::testing::Values(16), // jm_n ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch - ::testing::Values(65, 127, 234, 511, 768), // m - ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127, 234, 511), // m + ::testing::Values(65, 127, 234, 511), // n ::testing::Values(65, 127), // k ::testing::Values('n', 't'), // transa ::testing::Values('n', 't'), // transb @@ 
-82,7 +82,7 @@ const auto LargeMatricesHalfHalfm16n16k16 = ::testing::Combine( ::testing::Values(16), // jm_n ::testing::Values(16), // jm_n ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch ::testing::Values(1024, 1535, 2024), // m ::testing::Values(1024, 1535, 2024), // n diff --git a/test/unittest/joint_matrix/half_half_32_8_16.cpp b/test/unittest/joint_matrix/half_half_32_8_16.cpp index ad719d8e9..e220b884b 100644 --- a/test/unittest/joint_matrix/half_half_32_8_16.cpp +++ b/test/unittest/joint_matrix/half_half_32_8_16.cpp @@ -34,7 +34,7 @@ const auto SmallMatricesHalfHalfm32n8k16 = ::testing::Combine( ::testing::Values(8), // jm_n ::testing::Values(16), // jm_n ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch ::testing::Values(11, 16, 32, 63), // m ::testing::Values(11, 16, 32, 63), // n @@ -58,10 +58,10 @@ const auto MediumMatricesHalfHalfm32n8k16 = ::testing::Combine( ::testing::Values(8), // jm_n ::testing::Values(16), // jm_n ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch - ::testing::Values(65, 127, 234, 511, 768), // m - ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127, 234, 511), // m + ::testing::Values(65, 127, 234, 511), // n ::testing::Values(65, 127), // k ::testing::Values('n', 't'), // transa ::testing::Values('n', 't'), // transb @@ -82,7 +82,7 @@ const auto LargeMatricesHalfHalfm32n8k16 = ::testing::Combine( ::testing::Values(8), // jm_n ::testing::Values(16), // jm_n ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch ::testing::Values(1024, 1535, 2024), // m ::testing::Values(1024, 1535, 2024), // n 
diff --git a/test/unittest/joint_matrix/half_half_8_32_16.cpp b/test/unittest/joint_matrix/half_half_8_32_16.cpp index 090eab1ee..94a4ebeaf 100644 --- a/test/unittest/joint_matrix/half_half_8_32_16.cpp +++ b/test/unittest/joint_matrix/half_half_8_32_16.cpp @@ -34,7 +34,7 @@ const auto SmallMatricesHalfHalfm8n32k16 = ::testing::Combine( ::testing::Values(32), // jm_n ::testing::Values(16), // jm_n ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch ::testing::Values(11, 16, 32, 63), // m ::testing::Values(11, 16, 32, 63), // n @@ -58,10 +58,10 @@ const auto MediumMatricesHalfHalfm8n32k16 = ::testing::Combine( ::testing::Values(32), // jm_n ::testing::Values(16), // jm_n ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch - ::testing::Values(65, 127, 234, 511, 768), // m - ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127, 234, 511), // m + ::testing::Values(65, 127, 234, 511), // n ::testing::Values(65, 127), // k ::testing::Values('n', 't'), // transa ::testing::Values('n', 't'), // transb @@ -82,7 +82,7 @@ const auto LargeMatricesHalfHalfm8n32k16 = ::testing::Combine( ::testing::Values(32), // jm_n ::testing::Values(16), // jm_n ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch ::testing::Values(1024, 1535, 2024), // m ::testing::Values(1024, 1535, 2024), // n diff --git a/test/unittest/joint_matrix/joint_matrix_common.hpp b/test/unittest/joint_matrix/joint_matrix_common.hpp index 7998f9dc0..b6d9d85df 100644 --- a/test/unittest/joint_matrix/joint_matrix_common.hpp +++ b/test/unittest/joint_matrix/joint_matrix_common.hpp @@ -77,9 +77,18 @@ inline void verify_gemm(const joint_matrix_arguments_t arguments) { std::vector 
b_m(buffer_size_b); std::vector c_m_gpu(buffer_size_c); - fill_random(a_m); - fill_random(b_m); - fill_random(c_m_gpu); + if (jm_outType == "half") { + // initialize the vectors with positive values + // to avoid test failures for half precision + // accumulation + fill_random_with_range(a_m, scalar_t{1}, scalar_t{3}); + fill_random_with_range(b_m, scalar_t{1}, scalar_t{3}); + fill_random_with_range(c_m_gpu, scalar_t{1}, scalar_t{3}); + } else { + fill_random(a_m); + fill_random(b_m); + fill_random(c_m_gpu); + } index_t nbits = 13; if (jm_inType == "bfloat16") { @@ -184,9 +193,7 @@ inline void verify_gemm(const joint_matrix_arguments_t arguments) { blas::helper::copy_to_host(q, m_c_gpu, c_m_gpu.data(), buffer_size_c); sb_handle.wait(event); - const bool isAlmostEqual = utils::compare_vectors( - c_m_gpu, c_m_cpu, std::cerr, "", - jm_inType == "half" && jm_outType == "half" ? 3 : 1); + const bool isAlmostEqual = utils::compare_vectors(c_m_gpu, c_m_cpu); ASSERT_TRUE(isAlmostEqual); helper::deallocate(m_a_gpu, q); diff --git a/test/unittest/joint_matrix/tf32_float_16_16_8.cpp b/test/unittest/joint_matrix/tf32_float_16_16_8.cpp index 249c88c34..f7facadc3 100644 --- a/test/unittest/joint_matrix/tf32_float_16_16_8.cpp +++ b/test/unittest/joint_matrix/tf32_float_16_16_8.cpp @@ -34,7 +34,7 @@ const auto SmallMatricesTF32Floatm16n16k8 = ::testing::Combine( ::testing::Values(16), // jm_n ::testing::Values(8), // jm_k ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch ::testing::Values(11, 16, 32, 63), // m ::testing::Values(11, 16, 32, 63), // n @@ -58,10 +58,10 @@ const auto MediumMatricesTF32Floatm16n16k8 = ::testing::Combine( ::testing::Values(16), // jm_n ::testing::Values(8), // jm_k ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch - ::testing::Values(65, 127, 
234, 511, 768), // m - ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127, 234, 511), // m + ::testing::Values(65, 127, 234, 511), // n ::testing::Values(65, 127), // k ::testing::Values('n', 't'), // transa ::testing::Values('n', 't'), // transb @@ -82,7 +82,7 @@ const auto LargeMatricesTF32Floatm16n16k8 = ::testing::Combine( ::testing::Values(16), // jm_n ::testing::Values(8), // jm_k ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0, 33), // offset + ::testing::Values(33), // offset ::testing::Values(1), // batch ::testing::Values(1024, 1535, 2024), // m ::testing::Values(1024, 1535, 2024), // n From 50c150b1352a3801b0baac4e660506b285c56e09 Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Thu, 22 Feb 2024 16:22:48 +0000 Subject: [PATCH 21/21] Reduced the initializer range * Increase error margin --- test/unittest/joint_matrix/joint_matrix_common.hpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/unittest/joint_matrix/joint_matrix_common.hpp b/test/unittest/joint_matrix/joint_matrix_common.hpp index b6d9d85df..c18366a83 100644 --- a/test/unittest/joint_matrix/joint_matrix_common.hpp +++ b/test/unittest/joint_matrix/joint_matrix_common.hpp @@ -81,9 +81,9 @@ inline void verify_gemm(const joint_matrix_arguments_t arguments) { // initialize the vectors with positive values // to avoid test failures for half precision // accumulation - fill_random_with_range(a_m, scalar_t{1}, scalar_t{3}); - fill_random_with_range(b_m, scalar_t{1}, scalar_t{3}); - fill_random_with_range(c_m_gpu, scalar_t{1}, scalar_t{3}); + fill_random_with_range(a_m, scalar_t{1}, scalar_t{2}); + fill_random_with_range(b_m, scalar_t{1}, scalar_t{2}); + fill_random_with_range(c_m_gpu, scalar_t{1}, scalar_t{2}); } else { fill_random(a_m); fill_random(b_m); @@ -193,7 +193,8 @@ inline void verify_gemm(const joint_matrix_arguments_t arguments) { blas::helper::copy_to_host(q, m_c_gpu, c_m_gpu.data(), buffer_size_c); 
sb_handle.wait(event); - const bool isAlmostEqual = utils::compare_vectors(c_m_gpu, c_m_cpu); + const bool isAlmostEqual = utils::compare_vectors( + c_m_gpu, c_m_cpu, std::cerr, "\n", jm_outType == "half" ? 3 : 1); ASSERT_TRUE(isAlmostEqual); helper::deallocate(m_a_gpu, q);