From ed94e3cd742a916054f81f74dcfc02de70e61aff Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C5=81ukasz=20=C5=9Alusarczyk?=
Date: Tue, 15 Apr 2025 14:12:53 +0200
Subject: [PATCH 1/2] sycl: use DNN in the first part of
 ggml_sycl_mul_mat_batched_sycl

---
 ggml/src/ggml-sycl/gemm.hpp      | 17 ++++++-------
 ggml/src/ggml-sycl/ggml-sycl.cpp | 41 +++++++++++++++++++-------------
 2 files changed, 33 insertions(+), 25 deletions(-)

diff --git a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp
index 4ebbb5b66fb47..85ceb65a6e9db 100644
--- a/ggml/src/ggml-sycl/gemm.hpp
+++ b/ggml/src/ggml-sycl/gemm.hpp
@@ -32,16 +32,17 @@ class DnnlGemmWrapper {
         else static_assert(0);
     }
 
-    static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
-                                const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
+    static void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
+                         const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q,
+                         dnnl_dim_t batches = 1) {
         auto stream = ctx.stream_dnnl(q);
         auto eng = ctx.engine_dnnl(q);
-        dnnl::memory::dims a_dims = { m, k };
-        dnnl::memory::dims b_dims = { k, n };
-        dnnl::memory::dims c_dims = { m, n };
-        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
-        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
-        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
+        dnnl::memory::dims a_dims = { batches, m, k };
+        dnnl::memory::dims b_dims = { batches, k, n };
+        dnnl::memory::dims c_dims = { batches, m, n };
+        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::acb : tag::abc);
+        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::acb : tag::abc);
+        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);
 
         dnnl::primitive_attr primitive_attr;
         primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index 4d2fda0bfa6ae..09c4d72f9cb24 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -1986,7 +1986,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne10 = src1->ne[0];
-    
+
     GGML_ASSERT(ne00 == ne10);
 
     const int64_t row_diff = row_high - row_low;
@@ -2737,10 +2737,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
     GGML_ASSERT(!ggml_is_transposed(src1));
     GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
-
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
     queue_ptr main_stream = ctx.stream();;
 
@@ -2761,14 +2761,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
     sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
                                                        : src1_f16_alloc.get();
 
-    char * dst_t;
-
-    dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
-    dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
-
-    // dst strides
-    size_t nbd2 = dst->nb[2];
-    size_t nbd3 = dst->nb[3];
+    const dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
+    const dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
 
     const float alpha_f32 = 1.0f;
     const float beta_f32 = 0.0f;
@@ -2776,24 +2770,36 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
     const void * alpha = &alpha_f32;
     const void * beta = &beta_f32;
 
-    dst_t = (char *) dst_ddf;
+    char * dst_t = (char *) dst_ddf;
 
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
+    GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));
+    GGML_ASSERT(ne10 == ne00);
 
     // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
+    const auto r2 = ne12/ne02;
+    const auto r3 = ne13/ne03;
+    const auto ne23 = ne12*ne13;
 
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
+#ifdef GGML_SYCL_DNNL
+        // TODO: use strided dnnl::memory::desc ctor in row_gemm to relax below assertions
+        GGML_ASSERT(nb11/nb10 == ne10);
+        GGML_ASSERT(nb01/nb00 == ne00);
+
+        DnnlGemmWrapper::row_gemm(ctx, false, true, ne11, ne01, ne10, src1_f16,
+                                  DnnlGemmWrapper::to_dt<sycl::half>(), src0_as_f16, DnnlGemmWrapper::to_dt<sycl::half>(),
+                                  dst_t, DnnlGemmWrapper::to_dt<float>(), main_stream, ne23);
+#else
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
             *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
             (const char *) src0_as_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
             (const char *) src1_f16, dpct::library_data_t::real_half, nb11 / nb10, nb12 / nb10, beta, (char *) dst_t,
-            cu_data_type, ne01, nb2 / nb0, ne12 * ne13, cu_compute_type)));
+            cu_data_type, ne01, nb2 / nb0, ne23, cu_compute_type)));
+#endif
     } else {
-        const int ne23 = ne12*ne13;
 
         ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
         ggml_sycl_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
@@ -2821,7 +2827,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
                                        dst_t, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
                                        nb03, nb12_scaled, nb13_scaled,
-                                       nbd2, nbd3, r2, r3, item_ct1);
+                                       nb2, nb3, r2, r3, item_ct1);
                     });
             });
     }
 
@@ -3660,7 +3666,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
         return GGML_STATUS_SUCCESS;
     }
 
-    sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()));
+    sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
+
     model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
     ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
     model_sycl_graph.end_recording();

From 9c87fde390dfd30f816217007d15db5c44e09f75 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C5=81ukasz=20=C5=9Alusarczyk?=
Date: Fri, 18 Apr 2025 15:58:59 +0200
Subject: [PATCH 2/2] handling the case when nb11/nb10 != ne10

---
 ggml/src/ggml-sycl/gemm.hpp      | 33 +++++++++++++++++++++++++++-----
 ggml/src/ggml-sycl/ggml-sycl.cpp | 17 +++++++---------
 tests/test-backend-ops.cpp       |  4 +++-
 3 files changed, 38 insertions(+), 16 deletions(-)

diff --git a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp
index 85ceb65a6e9db..667c2d40129e6 100644
--- a/ggml/src/ggml-sycl/gemm.hpp
+++ b/ggml/src/ggml-sycl/gemm.hpp
@@ -32,16 +32,30 @@ class DnnlGemmWrapper {
         else static_assert(0);
     }
 
-    static void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
-                         const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q,
-                         dnnl_dim_t batches = 1) {
+    // matrix A has m rows, k columns
+    // matrix B has k rows, n columns
+    // nra - number of elements to skip when moving to the next row in A
+    // nrb - number of elements to skip when moving to the next row in B
+    // nca - number of elements to skip when moving to the next column in A
+    // ncb - number of elements to skip when moving to the next column in B
+    // stride_a - number of elements to skip when moving to the next A matrix
+    // stride_b - number of elements to skip when moving to the next B matrix
+    // batches - number of A matrices, equal to number of B matrices
+    static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
+                     const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
+                     const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
+                     void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches) {
+
         auto stream = ctx.stream_dnnl(q);
         auto eng = ctx.engine_dnnl(q);
         dnnl::memory::dims a_dims = { batches, m, k };
         dnnl::memory::dims b_dims = { batches, k, n };
         dnnl::memory::dims c_dims = { batches, m, n };
-        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::acb : tag::abc);
-        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::acb : tag::abc);
+        dnnl::memory::dims a_strides = { stride_a, nra, nca };
+        dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
+
+        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
+        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
         const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);
 
         dnnl::primitive_attr primitive_attr;
@@ -64,6 +78,15 @@ class DnnlGemmWrapper {
 
         matmul_prim.execute(stream, matmul_args);
     }
+
+    // matrices A and B are column major, both having k rows
+    // matrix A has m columns, matrix B has n columns
+    // output: column major matrix C = A transposed * B
+    static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
+                         const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
+
+        gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1);
+    }
 };
 
 #endif
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index 09c4d72f9cb24..33c82cbc6b4f5 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -2047,7 +2047,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
         const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
         to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
 #else
-        DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
+        DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
                                   DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
                                   dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
         const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
@@ -2081,7 +2081,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
             src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00,
             src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
 #else
-        DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
+        DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
                                   DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
                                   dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
 #endif
@@ -2784,14 +2784,11 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
 
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
-#ifdef GGML_SYCL_DNNL
-        // TODO: use strided dnnl::memory::desc ctor in row_gemm to relax below assertions
-        GGML_ASSERT(nb11/nb10 == ne10);
-        GGML_ASSERT(nb01/nb00 == ne00);
-
-        DnnlGemmWrapper::row_gemm(ctx, false, true, ne11, ne01, ne10, src1_f16,
-                                  DnnlGemmWrapper::to_dt<sycl::half>(), src0_as_f16, DnnlGemmWrapper::to_dt<sycl::half>(),
-                                  dst_t, DnnlGemmWrapper::to_dt<float>(), main_stream, ne23);
+#if GGML_SYCL_DNNL
+        DnnlGemmWrapper::gemm(ctx, ne11, ne01, ne10,
+                              src1_f16, DnnlGemmWrapper::to_dt<sycl::half>(), nb11/nb10, 1, nb12/nb10,
+                              src0_as_f16, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
+                              dst_t, DnnlGemmWrapper::to_dt<float>(), main_stream, ne23);
 #else
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
             *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 3a5741c8d959d..3e872c9cbf31e 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -3865,7 +3865,7 @@ static const ggml_type other_types[] = {
 // Test cases for evaluation: should try to cover edge cases while using small input sizes to keep the runtime low
 static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     std::vector<std::unique_ptr<test_case>> test_cases;
-    std::default_random_engine rng(0);
+    [[maybe_unused]] std::default_random_engine rng(0);
 
     // unary ops
     for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
@@ -4182,6 +4182,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
+
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 1}, {1, 1}, {0, 2, 1, 3}));
         }
     }
     for (ggml_type type_a : other_types) {
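
A note on the stride semantics introduced by patch 2: DnnlGemmWrapper::gemm describes each batch of A as an m x k matrix addressed with a row step nra, a column step nca and a batch step stride_a (likewise nrb/ncb/stride_b for B), and hands those steps to oneDNN via strided dnnl::memory::desc objects. The snippet below is a minimal, CPU-only sketch of that indexing for illustration; gemm_ref is a hypothetical name, not part of the patch, and the real code never loops on the host, it lets dnnl::matmul consume the strides.

// Illustration only (not part of the patch): reference semantics of the
// strided descriptors built in DnnlGemmWrapper::gemm.
#include <cstdio>
#include <vector>

// C[b] is an m x n matrix stored contiguously in row-major order (tag::abc):
//   C[b][i][j] = sum_l A[b*stride_a + i*nra + l*nca] * B[b*stride_b + l*nrb + j*ncb]
static void gemm_ref(int m, int n, int k,
                     const float * a, int nra, int nca, int stride_a,
                     const float * b, int nrb, int ncb, int stride_b,
                     float * c, int batches) {
    for (int p = 0; p < batches; ++p) {
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                float acc = 0.0f;
                for (int l = 0; l < k; ++l) {
                    acc += a[p*stride_a + i*nra + l*nca] * b[p*stride_b + l*nrb + j*ncb];
                }
                c[(p*m + i)*n + j] = acc;
            }
        }
    }
}

int main() {
    // row_gemm passes (nra=k, nca=1, stride_a=k*m) for A and (nrb=1, ncb=k,
    // stride_b=n*k) for B, i.e. column-major k x m and k x n inputs, which
    // yields C = A^T * B as stated in the row_gemm comment.
    const int m = 2, n = 3, k = 4;
    std::vector<float> A(k * m), B(k * n), C(m * n);
    for (int i = 0; i < k * m; ++i) A[i] = (float) i;       // A(row l, col i) = i*k + l
    for (int i = 0; i < k * n; ++i) B[i] = (float) (i % 5);
    gemm_ref(m, n, k, A.data(), k, 1, k * m, B.data(), 1, k, n * k, C.data(), 1);
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            printf("%6.1f ", C[i * n + j]);
        }
        printf("\n");
    }
    return 0;
}

Because the row and column steps are now explicit, the batched call in ggml_sycl_mul_mat_batched_sycl can pass the actual pitches nb11/nb10 and nb01/nb00 straight through as nra/ncb, which is why patch 2 drops the nb11/nb10 == ne10 and nb01/nb00 == ne00 assertions that patch 1 still needed.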