Skip to content
This repository was archived by the owner on Jan 13, 2025. It is now read-only.

Commit dd1c388

Browse files
Reduce .so library size (#475)
Removed the redundant GEMM configurations instantiated using Cmake
1 parent e0562e4 commit dd1c388

File tree

5 files changed

+428
-401
lines changed

5 files changed

+428
-401
lines changed

cmake/CmakeFunctionHelper.cmake

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,8 +294,14 @@ function(add_gemm_configuration
294294
if ((${data} MATCHES "complex") AND (symm_a OR symm_b))
295295
continue()
296296
endif()
297+
if (symm_a AND symm_b)
298+
continue()
299+
endif()
297300
foreach(trans_a ${boolean_list})
298301
foreach(trans_b ${boolean_list})
302+
if ((symm_a AND trans_b) OR (symm_b AND trans_a))
303+
continue()
304+
endif()
299305
foreach(is_beta_zero ${boolean_list})
300306
foreach(index ${index_list})
301307
set(file_name "${func}_${double_buffer}_${conflict_a}_"

src/interface/blas3/backend/amd_gpu.hpp

Lines changed: 100 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -41,107 +41,112 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
4141
container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
4242
gemm_batch_type_t batch_type,
4343
const typename sb_handle_t::event_t& _dependencies) {
44-
static constexpr int ClSize = 64;
45-
static constexpr int tileWgSize = ClSize / sizeof(element_t);
46-
if (batch_type == gemm_batch_type_t::interleaved) {
47-
return blas::Gemm_Launcher<
48-
container_0_t, container_1_t, container_2_t, 64, false, false, false,
49-
64, Tile<4, 4, 4, 4, 1, 1, 1, 1, 4, 4>, _t_a, _t_b, s_a, s_b,
50-
static_cast<int>(gemm_memory_t::no_local),
51-
static_cast<int>(gemm_algorithm_t::standard),
52-
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 4,
53-
static_cast<int>(gemm_batch_type_t::interleaved)>::
54-
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea,
55-
_b, _ldb, _strideb, _beta, _c, _ldc, _stridec,
56-
batch_size, _dependencies);
57-
}
58-
/* Tall & Skinny matrices. */
59-
#ifdef GEMM_TALL_SKINNY_SUPPORT
60-
if (batch_size == 1 &&
61-
((_K > 8192 && _M <= 1024 && _N <= 1024) ||
62-
(_K > 1024 && _M <= 256 && _N <= 256)) &&
63-
(!s_a && !s_b)) {
64-
if (_M <= 16 && _N > 32) {
65-
return blas::Gemm_Launcher<
66-
container_0_t, container_1_t, container_2_t, 256, true, true, true,
67-
ClSize, Tile<1, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
68-
static_cast<int>(gemm_memory_t::local),
69-
static_cast<int>(gemm_algorithm_t::tall_skinny),
70-
static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 2,
71-
static_cast<int>(gemm_batch_type_t::strided)>::
72-
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
73-
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
74-
_stridec, batch_size, _dependencies);
75-
} else if (_M > 64 && _N <= 32) {
76-
return blas::Gemm_Launcher<
77-
container_0_t, container_1_t, container_2_t, 256, true, true, true,
78-
ClSize, Tile<4, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
79-
static_cast<int>(gemm_memory_t::local),
80-
static_cast<int>(gemm_algorithm_t::tall_skinny),
81-
static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 2,
82-
static_cast<int>(gemm_batch_type_t::strided)>::
83-
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
84-
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
85-
_stridec, batch_size, _dependencies);
86-
} else if (_M <= 16 || _N <= 16) {
87-
return blas::Gemm_Launcher<
88-
container_0_t, container_1_t, container_2_t, 256, true, true, true,
89-
ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
90-
static_cast<int>(gemm_memory_t::local),
91-
static_cast<int>(gemm_algorithm_t::tall_skinny),
92-
static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 2,
93-
static_cast<int>(gemm_batch_type_t::strided)>::
94-
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
95-
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
96-
_stridec, batch_size, _dependencies);
97-
} else if (_M <= 32 || _N <= 32) {
98-
return blas::Gemm_Launcher<
99-
container_0_t, container_1_t, container_2_t, 256, true, true, true,
100-
ClSize, Tile<2, 2, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
101-
static_cast<int>(gemm_memory_t::local),
102-
static_cast<int>(gemm_algorithm_t::tall_skinny),
103-
static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 2,
104-
static_cast<int>(gemm_batch_type_t::strided)>::
105-
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
106-
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
107-
_stridec, batch_size, _dependencies);
108-
} else {
109-
return blas::Gemm_Launcher<
110-
container_0_t, container_1_t, container_2_t, 256, true, true, true,
111-
ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
112-
static_cast<int>(gemm_memory_t::local),
113-
static_cast<int>(gemm_algorithm_t::tall_skinny),
114-
static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 2,
115-
static_cast<int>(gemm_batch_type_t::strided)>::
116-
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
117-
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
118-
_stridec, batch_size, _dependencies);
119-
}
120-
} else
121-
#endif // GEMM_TALL_SKINNY_SUPPORT
122-
if (_M * _N <= 65536) {
123-
return blas::Gemm_Launcher<
124-
container_0_t, container_1_t, container_2_t, 256, false, false, false,
125-
ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
126-
static_cast<int>(gemm_memory_t::local),
127-
static_cast<int>(gemm_algorithm_t::standard),
128-
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
129-
static_cast<int>(gemm_batch_type_t::strided)>::
130-
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
131-
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
132-
_stridec, batch_size, _dependencies);
133-
} else {
44+
// Unused configuration cases
45+
if constexpr (s_a && s_b || ((s_a && _t_b) || (s_b && _t_a))) {
46+
return _dependencies;
47+
} else {
48+
static constexpr int ClSize = 64;
49+
static constexpr int tileWgSize = ClSize / sizeof(element_t);
50+
if (batch_type == gemm_batch_type_t::interleaved) {
13451
return blas::Gemm_Launcher<
135-
container_0_t, container_1_t, container_2_t, 256, false, false, false,
136-
ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
137-
static_cast<int>(gemm_memory_t::local),
52+
container_0_t, container_1_t, container_2_t, 64, false, false, false,
53+
64, Tile<4, 4, 4, 4, 1, 1, 1, 1, 4, 4>, _t_a, _t_b, s_a, s_b,
54+
static_cast<int>(gemm_memory_t::no_local),
13855
static_cast<int>(gemm_algorithm_t::standard),
139-
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 2,
140-
static_cast<int>(gemm_batch_type_t::strided)>::
56+
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 4,
57+
static_cast<int>(gemm_batch_type_t::interleaved)>::
14158
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
14259
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
14360
_stridec, batch_size, _dependencies);
14461
}
62+
/* Tall & Skinny matrices. */
63+
#ifdef GEMM_TALL_SKINNY_SUPPORT
64+
if (batch_size == 1 &&
65+
((_K > 8192 && _M <= 1024 && _N <= 1024) ||
66+
(_K > 1024 && _M <= 256 && _N <= 256)) &&
67+
(!s_a && !s_b)) {
68+
if (_M <= 16 && _N > 32) {
69+
return blas::Gemm_Launcher<
70+
container_0_t, container_1_t, container_2_t, 256, true, true, true,
71+
ClSize, Tile<1, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
72+
static_cast<int>(gemm_memory_t::local),
73+
static_cast<int>(gemm_algorithm_t::tall_skinny),
74+
static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 2,
75+
static_cast<int>(gemm_batch_type_t::strided)>::
76+
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
77+
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
78+
_stridec, batch_size, _dependencies);
79+
} else if (_M > 64 && _N <= 32) {
80+
return blas::Gemm_Launcher<
81+
container_0_t, container_1_t, container_2_t, 256, true, true, true,
82+
ClSize, Tile<4, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
83+
static_cast<int>(gemm_memory_t::local),
84+
static_cast<int>(gemm_algorithm_t::tall_skinny),
85+
static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 2,
86+
static_cast<int>(gemm_batch_type_t::strided)>::
87+
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
88+
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
89+
_stridec, batch_size, _dependencies);
90+
} else if (_M <= 16 || _N <= 16) {
91+
return blas::Gemm_Launcher<
92+
container_0_t, container_1_t, container_2_t, 256, true, true, true,
93+
ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
94+
static_cast<int>(gemm_memory_t::local),
95+
static_cast<int>(gemm_algorithm_t::tall_skinny),
96+
static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 2,
97+
static_cast<int>(gemm_batch_type_t::strided)>::
98+
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
99+
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
100+
_stridec, batch_size, _dependencies);
101+
} else if (_M <= 32 || _N <= 32) {
102+
return blas::Gemm_Launcher<
103+
container_0_t, container_1_t, container_2_t, 256, true, true, true,
104+
ClSize, Tile<2, 2, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
105+
static_cast<int>(gemm_memory_t::local),
106+
static_cast<int>(gemm_algorithm_t::tall_skinny),
107+
static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 2,
108+
static_cast<int>(gemm_batch_type_t::strided)>::
109+
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
110+
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
111+
_stridec, batch_size, _dependencies);
112+
} else {
113+
return blas::Gemm_Launcher<
114+
container_0_t, container_1_t, container_2_t, 256, true, true, true,
115+
ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
116+
static_cast<int>(gemm_memory_t::local),
117+
static_cast<int>(gemm_algorithm_t::tall_skinny),
118+
static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 2,
119+
static_cast<int>(gemm_batch_type_t::strided)>::
120+
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
121+
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
122+
_stridec, batch_size, _dependencies);
123+
}
124+
} else
125+
#endif // GEMM_TALL_SKINNY_SUPPORT
126+
if (_M * _N <= 65536) {
127+
return blas::Gemm_Launcher<
128+
container_0_t, container_1_t, container_2_t, 256, false, false,
129+
false, ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a,
130+
s_b, static_cast<int>(gemm_memory_t::local),
131+
static_cast<int>(gemm_algorithm_t::standard),
132+
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
133+
static_cast<int>(gemm_batch_type_t::strided)>::
134+
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
135+
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
136+
_stridec, batch_size, _dependencies);
137+
} else {
138+
return blas::Gemm_Launcher<
139+
container_0_t, container_1_t, container_2_t, 256, false, false,
140+
false, ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a,
141+
s_b, static_cast<int>(gemm_memory_t::local),
142+
static_cast<int>(gemm_algorithm_t::standard),
143+
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 2,
144+
static_cast<int>(gemm_batch_type_t::strided)>::
145+
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
146+
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
147+
_stridec, batch_size, _dependencies);
148+
}
149+
}
145150
}
146151

147152
// Complex Configurations

src/interface/blas3/backend/default_cpu.hpp

Lines changed: 53 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -41,66 +41,71 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
4141
container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
4242
gemm_batch_type_t batch_type,
4343
const typename sb_handle_t::event_t& _dependencies) {
44-
if (batch_type == gemm_batch_type_t::interleaved) {
45-
return blas::Gemm_Launcher<
46-
container_0_t, container_1_t, container_2_t, 64, false, false, false,
47-
64, Tile<2, 2, 4, 4, 1, 1, 1, 1, 4, 4>, _t_a, _t_b, s_a, s_b,
48-
static_cast<int>(gemm_memory_t::no_local),
49-
static_cast<int>(gemm_algorithm_t::standard),
50-
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 4,
51-
static_cast<int>(gemm_batch_type_t::interleaved)>::
52-
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea,
53-
_b, _ldb, _strideb, _beta, _c, _ldc, _stridec,
54-
batch_size, _dependencies);
55-
}
44+
// Unused configuration cases
45+
if constexpr (s_a && s_b || ((s_a && _t_b) || (s_b && _t_a))) {
46+
return _dependencies;
47+
} else {
48+
if (batch_type == gemm_batch_type_t::interleaved) {
49+
return blas::Gemm_Launcher<
50+
container_0_t, container_1_t, container_2_t, 64, false, false, false,
51+
64, Tile<2, 2, 4, 4, 1, 1, 1, 1, 4, 4>, _t_a, _t_b, s_a, s_b,
52+
static_cast<int>(gemm_memory_t::no_local),
53+
static_cast<int>(gemm_algorithm_t::standard),
54+
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 4,
55+
static_cast<int>(gemm_batch_type_t::interleaved)>::
56+
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
57+
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
58+
_stridec, batch_size, _dependencies);
59+
}
5660
#if defined(NAIVE_GEMM)
57-
return blas::Gemm_Launcher<
58-
container_0_t, container_1_t, container_2_t, 64, false, false, false, 64,
59-
Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b,
60-
static_cast<int>(gemm_memory_t::no_local),
61-
static_cast<int>(gemm_algorithm_t::naive),
62-
static_cast<int>(gemm_vectorization_t::partial), is_beta_zero, 1,
63-
static_cast<int>(gemm_batch_type_t::strided)>::
64-
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea,
65-
_b, _ldb, _strideb, _beta, _c, _ldc, _stridec,
66-
batch_size, _dependencies);
67-
#else
68-
if (_M <= 128 && _N <= 128 && _K <= 128 && !s_a && !s_b) {
69-
return blas::Gemm_Launcher<
70-
container_0_t, container_1_t, container_2_t, 64, false, false, false,
71-
64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b,
72-
static_cast<int>(gemm_memory_t::no_local),
73-
static_cast<int>(gemm_algorithm_t::standard),
74-
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 2,
75-
static_cast<int>(gemm_batch_type_t::strided)>::
76-
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea,
77-
_b, _ldb, _strideb, _beta, _c, _ldc, _stridec,
78-
batch_size, _dependencies);
79-
} else if (!s_a && !s_b) {
8061
return blas::Gemm_Launcher<
8162
container_0_t, container_1_t, container_2_t, 64, false, false, false,
8263
64, Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b,
8364
static_cast<int>(gemm_memory_t::no_local),
84-
static_cast<int>(gemm_algorithm_t::standard),
65+
static_cast<int>(gemm_algorithm_t::naive),
8566
static_cast<int>(gemm_vectorization_t::partial), is_beta_zero, 1,
8667
static_cast<int>(gemm_batch_type_t::strided)>::
8768
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea,
8869
_b, _ldb, _strideb, _beta, _c, _ldc, _stridec,
8970
batch_size, _dependencies);
90-
} else {
91-
return blas::Gemm_Launcher<
92-
container_0_t, container_1_t, container_2_t, 64, false, false, false,
93-
64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b,
94-
static_cast<int>(gemm_memory_t::local),
95-
static_cast<int>(gemm_algorithm_t::standard),
96-
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 2,
97-
static_cast<int>(gemm_batch_type_t::strided)>::
98-
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea,
99-
_b, _ldb, _strideb, _beta, _c, _ldc, _stridec,
100-
batch_size, _dependencies);
101-
}
71+
#else
72+
if (_M <= 128 && _N <= 128 && _K <= 128 && !s_a && !s_b) {
73+
return blas::Gemm_Launcher<
74+
container_0_t, container_1_t, container_2_t, 64, false, false, false,
75+
64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b,
76+
static_cast<int>(gemm_memory_t::no_local),
77+
static_cast<int>(gemm_algorithm_t::standard),
78+
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 2,
79+
static_cast<int>(gemm_batch_type_t::strided)>::
80+
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
81+
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
82+
_stridec, batch_size, _dependencies);
83+
} else if (!s_a && !s_b) {
84+
return blas::Gemm_Launcher<
85+
container_0_t, container_1_t, container_2_t, 64, false, false, false,
86+
64, Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b,
87+
static_cast<int>(gemm_memory_t::no_local),
88+
static_cast<int>(gemm_algorithm_t::standard),
89+
static_cast<int>(gemm_vectorization_t::partial), is_beta_zero, 1,
90+
static_cast<int>(gemm_batch_type_t::strided)>::
91+
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
92+
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
93+
_stridec, batch_size, _dependencies);
94+
} else {
95+
return blas::Gemm_Launcher<
96+
container_0_t, container_1_t, container_2_t, 64, false, false, false,
97+
64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b,
98+
static_cast<int>(gemm_memory_t::local),
99+
static_cast<int>(gemm_algorithm_t::standard),
100+
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 2,
101+
static_cast<int>(gemm_batch_type_t::strided)>::
102+
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
103+
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
104+
_stridec, batch_size, _dependencies);
105+
}
102106

103107
#endif
108+
}
104109
}
105110

106111
// Complex Configurations

0 commit comments

Comments
 (0)