Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Commit

Permalink
Reduce .so library size (#475)
Browse files Browse the repository at this point in the history
Removed the redundant GEMM configurations instantiated using CMake
  • Loading branch information
OuadiElfarouki authored Nov 14, 2023
1 parent e0562e4 commit dd1c388
Show file tree
Hide file tree
Showing 5 changed files with 428 additions and 401 deletions.
6 changes: 6 additions & 0 deletions cmake/CmakeFunctionHelper.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,14 @@ function(add_gemm_configuration
if ((${data} MATCHES "complex") AND (symm_a OR symm_b))
continue()
endif()
if (symm_a AND symm_b)
continue()
endif()
foreach(trans_a ${boolean_list})
foreach(trans_b ${boolean_list})
if ((symm_a AND trans_b) OR (symm_b AND trans_a))
continue()
endif()
foreach(is_beta_zero ${boolean_list})
foreach(index ${index_list})
set(file_name "${func}_${double_buffer}_${conflict_a}_"
Expand Down
195 changes: 100 additions & 95 deletions src/interface/blas3/backend/amd_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,107 +41,112 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
gemm_batch_type_t batch_type,
const typename sb_handle_t::event_t& _dependencies) {
static constexpr int ClSize = 64;
static constexpr int tileWgSize = ClSize / sizeof(element_t);
if (batch_type == gemm_batch_type_t::interleaved) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, false,
64, Tile<4, 4, 4, 4, 1, 1, 1, 1, 4, 4>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 4,
static_cast<int>(gemm_batch_type_t::interleaved)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea,
_b, _ldb, _strideb, _beta, _c, _ldc, _stridec,
batch_size, _dependencies);
}
/* Tall & Skinny matrices. */
#ifdef GEMM_TALL_SKINNY_SUPPORT
if (batch_size == 1 &&
((_K > 8192 && _M <= 1024 && _N <= 1024) ||
(_K > 1024 && _M <= 256 && _N <= 256)) &&
(!s_a && !s_b)) {
if (_M <= 16 && _N > 32) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, true, true, true,
ClSize, Tile<1, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::tall_skinny),
static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 2,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else if (_M > 64 && _N <= 32) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, true, true, true,
ClSize, Tile<4, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::tall_skinny),
static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 2,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else if (_M <= 16 || _N <= 16) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, true, true, true,
ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::tall_skinny),
static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 2,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else if (_M <= 32 || _N <= 32) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, true, true, true,
ClSize, Tile<2, 2, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::tall_skinny),
static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 2,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, true, true, true,
ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::tall_skinny),
static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 2,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
}
} else
#endif // GEMM_TALL_SKINNY_SUPPORT
if (_M * _N <= 65536) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false, false,
ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else {
// Unused configuration cases
if constexpr (s_a && s_b || ((s_a && _t_b) || (s_b && _t_a))) {
return _dependencies;
} else {
static constexpr int ClSize = 64;
static constexpr int tileWgSize = ClSize / sizeof(element_t);
if (batch_type == gemm_batch_type_t::interleaved) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false, false,
ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
container_0_t, container_1_t, container_2_t, 64, false, false, false,
64, Tile<4, 4, 4, 4, 1, 1, 1, 1, 4, 4>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 2,
static_cast<int>(gemm_batch_type_t::strided)>::
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 4,
static_cast<int>(gemm_batch_type_t::interleaved)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
}
/* Tall & Skinny matrices. */
#ifdef GEMM_TALL_SKINNY_SUPPORT
if (batch_size == 1 &&
((_K > 8192 && _M <= 1024 && _N <= 1024) ||
(_K > 1024 && _M <= 256 && _N <= 256)) &&
(!s_a && !s_b)) {
if (_M <= 16 && _N > 32) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, true, true, true,
ClSize, Tile<1, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::tall_skinny),
static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 2,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else if (_M > 64 && _N <= 32) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, true, true, true,
ClSize, Tile<4, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::tall_skinny),
static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 2,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else if (_M <= 16 || _N <= 16) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, true, true, true,
ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::tall_skinny),
static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 2,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else if (_M <= 32 || _N <= 32) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, true, true, true,
ClSize, Tile<2, 2, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::tall_skinny),
static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 2,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, true, true, true,
ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::tall_skinny),
static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 2,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
}
} else
#endif // GEMM_TALL_SKINNY_SUPPORT
if (_M * _N <= 65536) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false,
false, ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a,
s_b, static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false,
false, ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a,
s_b, static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 2,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
}
}
}

// Complex Configurations
Expand Down
101 changes: 53 additions & 48 deletions src/interface/blas3/backend/default_cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,66 +41,71 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
gemm_batch_type_t batch_type,
const typename sb_handle_t::event_t& _dependencies) {
if (batch_type == gemm_batch_type_t::interleaved) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, false,
64, Tile<2, 2, 4, 4, 1, 1, 1, 1, 4, 4>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 4,
static_cast<int>(gemm_batch_type_t::interleaved)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea,
_b, _ldb, _strideb, _beta, _c, _ldc, _stridec,
batch_size, _dependencies);
}
// Unused configuration cases
if constexpr (s_a && s_b || ((s_a && _t_b) || (s_b && _t_a))) {
return _dependencies;
} else {
if (batch_type == gemm_batch_type_t::interleaved) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, false,
64, Tile<2, 2, 4, 4, 1, 1, 1, 1, 4, 4>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 4,
static_cast<int>(gemm_batch_type_t::interleaved)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
}
#if defined(NAIVE_GEMM)
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, false, 64,
Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::naive),
static_cast<int>(gemm_vectorization_t::partial), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea,
_b, _ldb, _strideb, _beta, _c, _ldc, _stridec,
batch_size, _dependencies);
#else
if (_M <= 128 && _N <= 128 && _K <= 128 && !s_a && !s_b) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, false,
64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 2,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea,
_b, _ldb, _strideb, _beta, _c, _ldc, _stridec,
batch_size, _dependencies);
} else if (!s_a && !s_b) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, false,
64, Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_algorithm_t::naive),
static_cast<int>(gemm_vectorization_t::partial), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea,
_b, _ldb, _strideb, _beta, _c, _ldc, _stridec,
batch_size, _dependencies);
} else {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, false,
64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 2,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea,
_b, _ldb, _strideb, _beta, _c, _ldc, _stridec,
batch_size, _dependencies);
}
#else
if (_M <= 128 && _N <= 128 && _K <= 128 && !s_a && !s_b) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, false,
64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 2,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else if (!s_a && !s_b) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, false,
64, Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::partial), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, false,
64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 2,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
}

#endif
}
}

// Complex Configurations
Expand Down
Loading

0 comments on commit dd1c388

Please sign in to comment.