From dd1c388f99e862c81c1363b5779f9ab930b53555 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI <104583441+OuadiElfarouki@users.noreply.github.com> Date: Tue, 14 Nov 2023 10:32:13 +0000 Subject: [PATCH] Reduce .so library size (#475) Removed the redundant GEMM configurations instantiated using Cmake --- cmake/CmakeFunctionHelper.cmake | 6 + src/interface/blas3/backend/amd_gpu.hpp | 195 ++++++------ src/interface/blas3/backend/default_cpu.hpp | 101 +++--- src/interface/blas3/backend/intel_gpu.hpp | 327 ++++++++++---------- src/interface/blas3/backend/nvidia_gpu.hpp | 200 ++++++------ 5 files changed, 428 insertions(+), 401 deletions(-) diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake index 2ae71bc5e..3bed39572 100644 --- a/cmake/CmakeFunctionHelper.cmake +++ b/cmake/CmakeFunctionHelper.cmake @@ -294,8 +294,14 @@ function(add_gemm_configuration if ((${data} MATCHES "complex") AND (symm_a OR symm_b)) continue() endif() + if (symm_a AND symm_b) + continue() + endif() foreach(trans_a ${boolean_list}) foreach(trans_b ${boolean_list}) + if ((symm_a AND trans_b) OR (symm_b AND trans_a)) + continue() + endif() foreach(is_beta_zero ${boolean_list}) foreach(index ${index_list}) set(file_name "${func}_${double_buffer}_${conflict_a}_" diff --git a/src/interface/blas3/backend/amd_gpu.hpp b/src/interface/blas3/backend/amd_gpu.hpp index f494f25b9..9d6ffa424 100644 --- a/src/interface/blas3/backend/amd_gpu.hpp +++ b/src/interface/blas3/backend/amd_gpu.hpp @@ -41,107 +41,112 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { - static constexpr int ClSize = 64; - static constexpr int tileWgSize = ClSize / sizeof(element_t); - if (batch_type == gemm_batch_type_t::interleaved) { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<4, 4, 4, 4, 1, 1, 1, 1, 4, 4>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::no_local), - static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::full), is_beta_zero, 4, - static_cast(gemm_batch_type_t::interleaved)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, - _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, - batch_size, _dependencies); - } -/* Tall & Skinny matrices. */ -#ifdef GEMM_TALL_SKINNY_SUPPORT - if (batch_size == 1 && - ((_K > 8192 && _M <= 1024 && _N <= 1024) || - (_K > 1024 && _M <= 256 && _N <= 256)) && - (!s_a && !s_b)) { - if (_M <= 16 && _N > 32) { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 256, true, true, true, - ClSize, Tile<1, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::tall_skinny), - static_cast(gemm_vectorization_t::none), is_beta_zero, 2, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, - _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, - _stridec, batch_size, _dependencies); - } else if (_M > 64 && _N <= 32) { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 256, true, true, true, - ClSize, Tile<4, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::tall_skinny), - static_cast(gemm_vectorization_t::none), is_beta_zero, 2, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, - _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, - _stridec, batch_size, _dependencies); - } else if (_M <= 16 || _N <= 16) { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 256, true, true, true, - ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::tall_skinny), - static_cast(gemm_vectorization_t::none), is_beta_zero, 2, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, - _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, - _stridec, batch_size, _dependencies); - } else if (_M <= 32 || _N <= 32) { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 256, true, true, true, - ClSize, Tile<2, 2, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::tall_skinny), - static_cast(gemm_vectorization_t::none), is_beta_zero, 2, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, - _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, - _stridec, batch_size, _dependencies); - } else { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 256, true, true, true, - ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::tall_skinny), - static_cast(gemm_vectorization_t::none), is_beta_zero, 2, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, - _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, - _stridec, batch_size, _dependencies); - } - } else -#endif // GEMM_TALL_SKINNY_SUPPORT - if (_M * _N <= 65536) { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 256, false, false, false, - ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::full), is_beta_zero, 1, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, - _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, - _stridec, batch_size, _dependencies); - } else { + // Unused configuration cases + if constexpr (s_a && s_b || ((s_a && _t_b) || (s_b && _t_a))) { + return _dependencies; + } else { + static constexpr int ClSize = 64; + static constexpr int tileWgSize = ClSize / sizeof(element_t); + if (batch_type == gemm_batch_type_t::interleaved) { return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 256, false, false, false, - ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), + container_0_t, container_1_t, container_2_t, 64, false, false, false, + 64, Tile<4, 4, 4, 4, 1, 1, 1, 1, 4, 4>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::no_local), static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::full), is_beta_zero, 2, - static_cast(gemm_batch_type_t::strided)>:: + static_cast(gemm_vectorization_t::full), is_beta_zero, 4, + static_cast(gemm_batch_type_t::interleaved)>:: template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); } +/* Tall & Skinny matrices. */ +#ifdef GEMM_TALL_SKINNY_SUPPORT + if (batch_size == 1 && + ((_K > 8192 && _M <= 1024 && _N <= 1024) || + (_K > 1024 && _M <= 256 && _N <= 256)) && + (!s_a && !s_b)) { + if (_M <= 16 && _N > 32) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, true, true, true, + ClSize, Tile<1, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::tall_skinny), + static_cast(gemm_vectorization_t::none), is_beta_zero, 2, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } else if (_M > 64 && _N <= 32) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, true, true, true, + ClSize, Tile<4, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::tall_skinny), + static_cast(gemm_vectorization_t::none), is_beta_zero, 2, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } else if (_M <= 16 || _N <= 16) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, true, true, true, + ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::tall_skinny), + static_cast(gemm_vectorization_t::none), is_beta_zero, 2, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } else if (_M <= 32 || _N <= 32) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, true, true, true, + ClSize, Tile<2, 2, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::tall_skinny), + static_cast(gemm_vectorization_t::none), is_beta_zero, 2, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, true, true, true, + ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::tall_skinny), + static_cast(gemm_vectorization_t::none), is_beta_zero, 2, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } + } else +#endif // GEMM_TALL_SKINNY_SUPPORT + if (_M * _N <= 65536) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, false, + false, ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, + s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, false, + false, ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, + s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 2, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } + } } // Complex Configurations diff --git a/src/interface/blas3/backend/default_cpu.hpp b/src/interface/blas3/backend/default_cpu.hpp index e62348363..4fbf341e2 100644 --- a/src/interface/blas3/backend/default_cpu.hpp +++ b/src/interface/blas3/backend/default_cpu.hpp @@ -41,66 +41,71 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { - if (batch_type == gemm_batch_type_t::interleaved) { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<2, 2, 4, 4, 1, 1, 1, 1, 4, 4>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::no_local), - static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::full), is_beta_zero, 4, - static_cast(gemm_batch_type_t::interleaved)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, - _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, - batch_size, _dependencies); - } + // Unused configuration cases + if constexpr (s_a && s_b || ((s_a && _t_b) || (s_b && _t_a))) { + return _dependencies; + } else { + if (batch_type == gemm_batch_type_t::interleaved) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, false, false, false, + 64, Tile<2, 2, 4, 4, 1, 1, 1, 1, 4, 4>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::no_local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 4, + static_cast(gemm_batch_type_t::interleaved)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } #if defined(NAIVE_GEMM) - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, false, false, false, 64, - Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::no_local), - static_cast(gemm_algorithm_t::naive), - static_cast(gemm_vectorization_t::partial), is_beta_zero, 1, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, - _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, - batch_size, _dependencies); -#else - if (_M <= 128 && _N <= 128 && _K <= 128 && !s_a && !s_b) { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::no_local), - static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::full), is_beta_zero, 2, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, - _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, - batch_size, _dependencies); - } else if (!s_a && !s_b) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, 64, Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::no_local), - static_cast(gemm_algorithm_t::standard), + static_cast(gemm_algorithm_t::naive), static_cast(gemm_vectorization_t::partial), is_beta_zero, 1, static_cast(gemm_batch_type_t::strided)>:: template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); - } else { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::full), is_beta_zero, 2, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, - _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, - batch_size, _dependencies); - } +#else + if (_M <= 128 && _N <= 128 && _K <= 128 && !s_a && !s_b) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, false, false, false, + 64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::no_local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 2, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } else if (!s_a && !s_b) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, false, false, false, + 64, Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::no_local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::partial), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, false, false, false, + 64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 2, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } #endif + } } // Complex Configurations diff --git a/src/interface/blas3/backend/intel_gpu.hpp b/src/interface/blas3/backend/intel_gpu.hpp index 8d788c9b5..77b7e232f 100644 --- a/src/interface/blas3/backend/intel_gpu.hpp +++ b/src/interface/blas3/backend/intel_gpu.hpp @@ -40,171 +40,176 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { - if (batch_type == gemm_batch_type_t::interleaved) { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<4, 4, 4, 4, 1, 1, 1, 1, 4, 4>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::no_local), - static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::full), is_beta_zero, 4, - static_cast(gemm_batch_type_t::interleaved)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, - _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, - batch_size, _dependencies); - } + // Unused configuration cases + if constexpr (s_a && s_b || ((s_a && _t_b) || (s_b && _t_a))) { + return _dependencies; + } else { + if (batch_type == gemm_batch_type_t::interleaved) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, false, false, false, + 64, Tile<4, 4, 4, 4, 1, 1, 1, 1, 4, 4>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::no_local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 4, + static_cast(gemm_batch_type_t::interleaved)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } #ifdef GEMM_TALL_SKINNY_SUPPORT - if (!s_a && !s_b) { - /* Tall & Skinny matrices. */ - if (batch_size == 1 && - ((_K >= 4096 && _M * _N <= 16384) || (_K >= 1024 && _M * _N <= 4096))) { - if (_M >= 16 && _N <= 4) { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 32, true, true, true, - 64, Tile<2, 1, 8, 4>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::tall_skinny), - static_cast(gemm_vectorization_t::none), is_beta_zero, 4, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, - _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, - _stridec, batch_size, _dependencies); - } else if (_M <= 4 || _N <= 4) { - // Need to increase the work group size for cl::sycl::half for the - // launcher to be instancianted - constexpr int wg_size = sizeof(element_t) == 2 ? 8 : 4; - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 16, true, false, false, - 64, Tile<1, 1, wg_size, wg_size>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::tall_skinny), - static_cast(gemm_vectorization_t::none), is_beta_zero, 4, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, - _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, - _stridec, batch_size, _dependencies); - } else if (_M >= 16 && _N <= 8) { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 32, true, true, true, - 64, Tile<2, 2, 8, 4>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::tall_skinny), - static_cast(gemm_vectorization_t::none), is_beta_zero, 4, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, - _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, - _stridec, batch_size, _dependencies); - } else if (_M <= 8 || _N <= 8) { - // Need to increase the work group size for cl::sycl::half for the - // launcher to be instancianted - constexpr int wg_size = sizeof(element_t) == 2 ? 8 : 4; - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 16, true, false, false, - 64, Tile<2, 2, wg_size, wg_size>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::tall_skinny), - static_cast(gemm_vectorization_t::none), is_beta_zero, 4, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, - _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, - _stridec, batch_size, _dependencies); - } else if (_M <= 16 || _N <= 16) { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, true, true, true, - 64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::tall_skinny), - static_cast(gemm_vectorization_t::none), is_beta_zero, 4, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, - _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, - _stridec, batch_size, _dependencies); - } else if (_M <= 32 || _N <= 32) { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, true, true, true, - 64, Tile<4, 4, 8, 8>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::tall_skinny), - static_cast(gemm_vectorization_t::none), is_beta_zero, 4, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, - _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, - _stridec, batch_size, _dependencies); - } else { - constexpr int wg_size = sizeof(element_t) == 8 ? 8 : 16; - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 256, true, true, true, - 64, Tile<4, 4, wg_size, wg_size>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::tall_skinny), - static_cast(gemm_vectorization_t::none), is_beta_zero, 4, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, - _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, - _stridec, batch_size, _dependencies); - } - } else if (batch_size == 1 && (_t_a || (_t_b && _M * _N > 1048576))) { - if (_M <= 64 || _N <= 64) { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, true, true, true, - 64, Tile<4, 4, 8, 8>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::tall_skinny), - static_cast(gemm_vectorization_t::none), is_beta_zero, 4, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, - _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, - _stridec, batch_size, _dependencies); - } else { - // Need to increase the work group size for double for the - // launcher to be instancianted - constexpr int wg_size = sizeof(element_t) == 8 ? 8 : 16; - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 256, true, true, true, - 64, Tile<4, 4, wg_size, wg_size>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::tall_skinny), - static_cast(gemm_vectorization_t::none), is_beta_zero, 4, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, - _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, - _stridec, batch_size, _dependencies); + if (!s_a && !s_b) { + /* Tall & Skinny matrices. */ + if (batch_size == 1 && ((_K >= 4096 && _M * _N <= 16384) || + (_K >= 1024 && _M * _N <= 4096))) { + if (_M >= 16 && _N <= 4) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 32, true, true, true, + 64, Tile<2, 1, 8, 4>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::tall_skinny), + static_cast(gemm_vectorization_t::none), is_beta_zero, 4, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } else if (_M <= 4 || _N <= 4) { + // Need to increase the work group size for cl::sycl::half for the + // launcher to be instancianted + constexpr int wg_size = sizeof(element_t) == 2 ? 8 : 4; + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 16, true, false, + false, 64, Tile<1, 1, wg_size, wg_size>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::tall_skinny), + static_cast(gemm_vectorization_t::none), is_beta_zero, 4, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } else if (_M >= 16 && _N <= 8) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 32, true, true, true, + 64, Tile<2, 2, 8, 4>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::tall_skinny), + static_cast(gemm_vectorization_t::none), is_beta_zero, 4, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } else if (_M <= 8 || _N <= 8) { + // Need to increase the work group size for cl::sycl::half for the + // launcher to be instancianted + constexpr int wg_size = sizeof(element_t) == 2 ? 8 : 4; + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 16, true, false, + false, 64, Tile<2, 2, wg_size, wg_size>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::tall_skinny), + static_cast(gemm_vectorization_t::none), is_beta_zero, 4, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } else if (_M <= 16 || _N <= 16) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, true, true, true, + 64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::tall_skinny), + static_cast(gemm_vectorization_t::none), is_beta_zero, 4, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } else if (_M <= 32 || _N <= 32) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, true, true, true, + 64, Tile<4, 4, 8, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::tall_skinny), + static_cast(gemm_vectorization_t::none), is_beta_zero, 4, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } else { + constexpr int wg_size = sizeof(element_t) == 8 ? 8 : 16; + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, true, true, + true, 64, Tile<4, 4, wg_size, wg_size>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::tall_skinny), + static_cast(gemm_vectorization_t::none), is_beta_zero, 4, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } + } else if (batch_size == 1 && (_t_a || (_t_b && _M * _N > 1048576))) { + if (_M <= 64 || _N <= 64) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, true, true, true, + 64, Tile<4, 4, 8, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::tall_skinny), + static_cast(gemm_vectorization_t::none), is_beta_zero, 4, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } else { + // Need to increase the work group size for double for the + // launcher to be instancianted + constexpr int wg_size = sizeof(element_t) == 8 ? 8 : 16; + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, true, true, + true, 64, Tile<4, 4, wg_size, wg_size>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::tall_skinny), + static_cast(gemm_vectorization_t::none), is_beta_zero, 4, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } } } - } #endif - if (_M <= 128 && _N <= 128) { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, true, false, false, 64, - Tile<4, 4, 8, 8>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::full), is_beta_zero, 4, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, - _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, - batch_size, _dependencies); - } else if (_t_b && !_t_a && !s_a && !s_b) { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::no_local), - static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::partial), is_beta_zero, 4, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, - _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, - batch_size, _dependencies); - } else { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<4, 8, 16, 8>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::full), is_beta_zero, 4, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, - _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, - batch_size, _dependencies); + if (_M <= 128 && _N <= 128) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, true, false, false, + 64, Tile<4, 4, 8, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 4, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } else if (_t_b && !_t_a && !s_a && !s_b) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, false, false, false, + 64, Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::no_local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::partial), is_beta_zero, 4, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, false, false, false, + 64, Tile<4, 8, 16, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 4, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } } } diff --git a/src/interface/blas3/backend/nvidia_gpu.hpp b/src/interface/blas3/backend/nvidia_gpu.hpp index 13966172e..72ef7160d 100644 --- a/src/interface/blas3/backend/nvidia_gpu.hpp +++ b/src/interface/blas3/backend/nvidia_gpu.hpp @@ -41,73 +41,126 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { - if (batch_type == gemm_batch_type_t::interleaved) { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<2, 2, 4, 4, 1, 1, 1, 1, 4, 4, 1, 1, 1, float, float>, _t_a, - _t_b, s_a, s_b, static_cast(gemm_memory_t::no_local), - static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::full), is_beta_zero, 4, - static_cast(gemm_batch_type_t::interleaved)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, - _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, - batch_size, _dependencies); - } + // Unused configuration cases + if constexpr (s_a && s_b || ((s_a && _t_b) || (s_b && _t_a))) { + return _dependencies; + } else { + if (batch_type == gemm_batch_type_t::interleaved) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, false, false, false, + 64, Tile<2, 2, 4, 4, 1, 1, 1, 1, 4, 4, 1, 1, 1, float, float>, _t_a, + _t_b, s_a, s_b, static_cast(gemm_memory_t::no_local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 4, + static_cast(gemm_batch_type_t::interleaved)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } #ifdef SB_ENABLE_JOINT_MATRIX - const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX"); - if (en_joint_matrix != NULL && *en_joint_matrix == '1' && !s_a && !s_b && - std::is_same::type, float>::value && - std::is_same::type, float>::value && - std::is_same::type, float>::value) { - if (_M > 1024 && _N > 1024) { + const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX"); + if (en_joint_matrix != NULL && *en_joint_matrix == '1' && !s_a && !s_b && + std::is_same::type, float>::value && + std::is_same::type, float>::value && + std::is_same::type, float>::value) { + if (_M > 1024 && _N > 1024) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, true, true, + 128, + Tile<8, 8, 16, 16, 16, 2, 1, 1, 1, 1, 16, 16, 16, cl::sycl::half, + float>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, + _lda, _stridea, _b, _ldb, _strideb, + _beta, _c, _ldc, _stridec, batch_size, + _dependencies); + } else if (_M > 64 && _N > 64) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 128, false, true, true, + 128, + Tile<4, 8, 16, 8, 16, 2, 1, 1, 1, 1, 16, 16, 16, cl::sycl::half, + float>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, + _lda, _stridea, _b, _ldb, _strideb, + _beta, _c, _ldc, _stridec, batch_size, + _dependencies); + + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 128, false, true, true, + 128, + Tile<2, 4, 16, 8, 16, 2, 1, 1, 1, 1, 16, 16, 16, cl::sycl::half, + float>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, + _lda, _stridea, _b, _ldb, _strideb, + _beta, _c, _ldc, _stridec, batch_size, + _dependencies); + } + } +#endif // SB_ENABLE_JOINT_MATRIX + + if (batch_size > 1) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, true, true, - 128, - Tile<8, 8, 16, 16, 16, 2, 1, 1, 1, 1, 16, 16, 16, cl::sycl::half, - float>, + 128, Tile<8, 8, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, static_cast(gemm_batch_type_t::strided), - true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, - _stridea, _b, _ldb, _strideb, _beta, _c, - _ldc, _stridec, batch_size, - _dependencies); - } else if (_M > 64 && _N > 64) { + false>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, + _dependencies); + } else if (_M <= 256 && _N <= 256) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 128, false, true, true, - 128, - Tile<4, 8, 16, 8, 16, 2, 1, 1, 1, 1, 16, 16, 16, cl::sycl::half, - float>, - _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + 128, Tile<2, 2, 16, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, _t_a, + _t_b, s_a, s_b, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, static_cast(gemm_batch_type_t::strided), - true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, - _stridea, _b, _ldb, _strideb, _beta, _c, - _ldc, _stridec, batch_size, - _dependencies); - - } else { + false>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, + _dependencies); + } else if (_M <= 1024 && _N <= 1024) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 128, false, true, true, - 128, - Tile<2, 4, 16, 8, 16, 2, 1, 1, 1, 1, 16, 16, 16, cl::sycl::half, - float>, - _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + 128, Tile<4, 4, 16, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, _t_a, + _t_b, s_a, s_b, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, static_cast(gemm_batch_type_t::strided), - true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, - _stridea, _b, _ldb, _strideb, _beta, _c, - _ldc, _stridec, batch_size, - _dependencies); + false>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, + _dependencies); + } else if (_M <= 2048 && _N <= 2048) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 128, false, true, true, + 128, Tile<8, 8, 16, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, _t_a, + _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + false>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, + _dependencies); } - } -#endif // SB_ENABLE_JOINT_MATRIX - if (batch_size > 1) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, true, true, 128, Tile<8, 8, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, _t_a, @@ -119,54 +172,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); - } else if (_M <= 256 && _N <= 256) { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 128, false, true, true, - 128, Tile<2, 2, 16, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, _t_a, - _t_b, s_a, s_b, static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::full), is_beta_zero, 1, - static_cast(gemm_batch_type_t::strided), - false>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, - _stridea, _b, _ldb, _strideb, _beta, _c, - _ldc, _stridec, batch_size, - _dependencies); - } else if (_M <= 1024 && _N <= 1024) { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 128, false, true, true, - 128, Tile<4, 4, 16, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, _t_a, - _t_b, s_a, s_b, static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::full), is_beta_zero, 1, - static_cast(gemm_batch_type_t::strided), - false>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, - _stridea, _b, _ldb, _strideb, _beta, _c, - _ldc, _stridec, batch_size, - _dependencies); - } else if (_M <= 2048 && _N <= 2048) { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 128, false, true, true, - 128, Tile<8, 8, 16, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, _t_a, - _t_b, s_a, s_b, static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::full), is_beta_zero, 1, - static_cast(gemm_batch_type_t::strided), - false>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, - _stridea, _b, _ldb, _strideb, _beta, _c, - _ldc, _stridec, batch_size, - _dependencies); } - - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 256, false, true, true, 128, - Tile<8, 8, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, _t_a, _t_b, - s_a, s_b, static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::full), is_beta_zero, 1, - static_cast(gemm_batch_type_t::strided), - false>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, - _stridea, _b, _ldb, _strideb, _beta, _c, - _ldc, _stridec, batch_size, _dependencies); } // Complex Configurations