@@ -41,107 +41,112 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
41
41
container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
42
42
gemm_batch_type_t batch_type,
43
43
const typename sb_handle_t ::event_t & _dependencies) {
44
- static constexpr int ClSize = 64 ;
45
- static constexpr int tileWgSize = ClSize / sizeof (element_t );
46
- if (batch_type == gemm_batch_type_t ::interleaved) {
47
- return blas::Gemm_Launcher<
48
- container_0_t , container_1_t , container_2_t , 64 , false , false , false ,
49
- 64 , Tile<4 , 4 , 4 , 4 , 1 , 1 , 1 , 1 , 4 , 4 >, _t_a, _t_b, s_a, s_b,
50
- static_cast <int >(gemm_memory_t ::no_local),
51
- static_cast <int >(gemm_algorithm_t ::standard),
52
- static_cast <int >(gemm_vectorization_t ::full), is_beta_zero, 4 ,
53
- static_cast <int >(gemm_batch_type_t ::interleaved)>::
54
- template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea,
55
- _b, _ldb, _strideb, _beta, _c, _ldc, _stridec,
56
- batch_size, _dependencies);
57
- }
58
- /* Tall & Skinny matrices. */
59
- #ifdef GEMM_TALL_SKINNY_SUPPORT
60
- if (batch_size == 1 &&
61
- ((_K > 8192 && _M <= 1024 && _N <= 1024 ) ||
62
- (_K > 1024 && _M <= 256 && _N <= 256 )) &&
63
- (!s_a && !s_b)) {
64
- if (_M <= 16 && _N > 32 ) {
65
- return blas::Gemm_Launcher<
66
- container_0_t , container_1_t , container_2_t , 256 , true , true , true ,
67
- ClSize, Tile<1 , 4 , tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
68
- static_cast <int >(gemm_memory_t ::local),
69
- static_cast <int >(gemm_algorithm_t ::tall_skinny),
70
- static_cast <int >(gemm_vectorization_t ::none), is_beta_zero, 2 ,
71
- static_cast <int >(gemm_batch_type_t ::strided)>::
72
- template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda,
73
- _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
74
- _stridec, batch_size, _dependencies);
75
- } else if (_M > 64 && _N <= 32 ) {
76
- return blas::Gemm_Launcher<
77
- container_0_t , container_1_t , container_2_t , 256 , true , true , true ,
78
- ClSize, Tile<4 , 1 , tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
79
- static_cast <int >(gemm_memory_t ::local),
80
- static_cast <int >(gemm_algorithm_t ::tall_skinny),
81
- static_cast <int >(gemm_vectorization_t ::none), is_beta_zero, 2 ,
82
- static_cast <int >(gemm_batch_type_t ::strided)>::
83
- template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda,
84
- _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
85
- _stridec, batch_size, _dependencies);
86
- } else if (_M <= 16 || _N <= 16 ) {
87
- return blas::Gemm_Launcher<
88
- container_0_t , container_1_t , container_2_t , 256 , true , true , true ,
89
- ClSize, Tile<1 , 1 , tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
90
- static_cast <int >(gemm_memory_t ::local),
91
- static_cast <int >(gemm_algorithm_t ::tall_skinny),
92
- static_cast <int >(gemm_vectorization_t ::none), is_beta_zero, 2 ,
93
- static_cast <int >(gemm_batch_type_t ::strided)>::
94
- template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda,
95
- _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
96
- _stridec, batch_size, _dependencies);
97
- } else if (_M <= 32 || _N <= 32 ) {
98
- return blas::Gemm_Launcher<
99
- container_0_t , container_1_t , container_2_t , 256 , true , true , true ,
100
- ClSize, Tile<2 , 2 , tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
101
- static_cast <int >(gemm_memory_t ::local),
102
- static_cast <int >(gemm_algorithm_t ::tall_skinny),
103
- static_cast <int >(gemm_vectorization_t ::none), is_beta_zero, 2 ,
104
- static_cast <int >(gemm_batch_type_t ::strided)>::
105
- template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda,
106
- _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
107
- _stridec, batch_size, _dependencies);
108
- } else {
109
- return blas::Gemm_Launcher<
110
- container_0_t , container_1_t , container_2_t , 256 , true , true , true ,
111
- ClSize, Tile<4 , 4 , tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
112
- static_cast <int >(gemm_memory_t ::local),
113
- static_cast <int >(gemm_algorithm_t ::tall_skinny),
114
- static_cast <int >(gemm_vectorization_t ::none), is_beta_zero, 2 ,
115
- static_cast <int >(gemm_batch_type_t ::strided)>::
116
- template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda,
117
- _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
118
- _stridec, batch_size, _dependencies);
119
- }
120
- } else
121
- #endif // GEMM_TALL_SKINNY_SUPPORT
122
- if (_M * _N <= 65536 ) {
123
- return blas::Gemm_Launcher<
124
- container_0_t , container_1_t , container_2_t , 256 , false , false , false ,
125
- ClSize, Tile<1 , 1 , tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
126
- static_cast <int >(gemm_memory_t ::local),
127
- static_cast <int >(gemm_algorithm_t ::standard),
128
- static_cast <int >(gemm_vectorization_t ::full), is_beta_zero, 1 ,
129
- static_cast <int >(gemm_batch_type_t ::strided)>::
130
- template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda,
131
- _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
132
- _stridec, batch_size, _dependencies);
133
- } else {
44
+ // Unused configuration cases
45
+ if constexpr (s_a && s_b || ((s_a && _t_b) || (s_b && _t_a))) {
46
+ return _dependencies;
47
+ } else {
48
+ static constexpr int ClSize = 64 ;
49
+ static constexpr int tileWgSize = ClSize / sizeof (element_t );
50
+ if (batch_type == gemm_batch_type_t ::interleaved) {
134
51
return blas::Gemm_Launcher<
135
- container_0_t , container_1_t , container_2_t , 256 , false , false , false ,
136
- ClSize , Tile<4 , 4 , tileWgSize, tileWgSize >, _t_a, _t_b, s_a, s_b,
137
- static_cast <int >(gemm_memory_t ::local ),
52
+ container_0_t , container_1_t , container_2_t , 64 , false , false , false ,
53
+ 64 , Tile<4 , 4 , 4 , 4 , 1 , 1 , 1 , 1 , 4 , 4 >, _t_a, _t_b, s_a, s_b,
54
+ static_cast <int >(gemm_memory_t ::no_local ),
138
55
static_cast <int >(gemm_algorithm_t ::standard),
139
- static_cast <int >(gemm_vectorization_t ::full), is_beta_zero, 2 ,
140
- static_cast <int >(gemm_batch_type_t ::strided )>::
56
+ static_cast <int >(gemm_vectorization_t ::full), is_beta_zero, 4 ,
57
+ static_cast <int >(gemm_batch_type_t ::interleaved )>::
141
58
template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda,
142
59
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
143
60
_stridec, batch_size, _dependencies);
144
61
}
62
+ /* Tall & Skinny matrices. */
63
+ #ifdef GEMM_TALL_SKINNY_SUPPORT
64
+ if (batch_size == 1 &&
65
+ ((_K > 8192 && _M <= 1024 && _N <= 1024 ) ||
66
+ (_K > 1024 && _M <= 256 && _N <= 256 )) &&
67
+ (!s_a && !s_b)) {
68
+ if (_M <= 16 && _N > 32 ) {
69
+ return blas::Gemm_Launcher<
70
+ container_0_t , container_1_t , container_2_t , 256 , true , true , true ,
71
+ ClSize, Tile<1 , 4 , tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
72
+ static_cast <int >(gemm_memory_t ::local),
73
+ static_cast <int >(gemm_algorithm_t ::tall_skinny),
74
+ static_cast <int >(gemm_vectorization_t ::none), is_beta_zero, 2 ,
75
+ static_cast <int >(gemm_batch_type_t ::strided)>::
76
+ template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda,
77
+ _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
78
+ _stridec, batch_size, _dependencies);
79
+ } else if (_M > 64 && _N <= 32 ) {
80
+ return blas::Gemm_Launcher<
81
+ container_0_t , container_1_t , container_2_t , 256 , true , true , true ,
82
+ ClSize, Tile<4 , 1 , tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
83
+ static_cast <int >(gemm_memory_t ::local),
84
+ static_cast <int >(gemm_algorithm_t ::tall_skinny),
85
+ static_cast <int >(gemm_vectorization_t ::none), is_beta_zero, 2 ,
86
+ static_cast <int >(gemm_batch_type_t ::strided)>::
87
+ template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda,
88
+ _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
89
+ _stridec, batch_size, _dependencies);
90
+ } else if (_M <= 16 || _N <= 16 ) {
91
+ return blas::Gemm_Launcher<
92
+ container_0_t , container_1_t , container_2_t , 256 , true , true , true ,
93
+ ClSize, Tile<1 , 1 , tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
94
+ static_cast <int >(gemm_memory_t ::local),
95
+ static_cast <int >(gemm_algorithm_t ::tall_skinny),
96
+ static_cast <int >(gemm_vectorization_t ::none), is_beta_zero, 2 ,
97
+ static_cast <int >(gemm_batch_type_t ::strided)>::
98
+ template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda,
99
+ _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
100
+ _stridec, batch_size, _dependencies);
101
+ } else if (_M <= 32 || _N <= 32 ) {
102
+ return blas::Gemm_Launcher<
103
+ container_0_t , container_1_t , container_2_t , 256 , true , true , true ,
104
+ ClSize, Tile<2 , 2 , tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
105
+ static_cast <int >(gemm_memory_t ::local),
106
+ static_cast <int >(gemm_algorithm_t ::tall_skinny),
107
+ static_cast <int >(gemm_vectorization_t ::none), is_beta_zero, 2 ,
108
+ static_cast <int >(gemm_batch_type_t ::strided)>::
109
+ template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda,
110
+ _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
111
+ _stridec, batch_size, _dependencies);
112
+ } else {
113
+ return blas::Gemm_Launcher<
114
+ container_0_t , container_1_t , container_2_t , 256 , true , true , true ,
115
+ ClSize, Tile<4 , 4 , tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
116
+ static_cast <int >(gemm_memory_t ::local),
117
+ static_cast <int >(gemm_algorithm_t ::tall_skinny),
118
+ static_cast <int >(gemm_vectorization_t ::none), is_beta_zero, 2 ,
119
+ static_cast <int >(gemm_batch_type_t ::strided)>::
120
+ template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda,
121
+ _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
122
+ _stridec, batch_size, _dependencies);
123
+ }
124
+ } else
125
+ #endif // GEMM_TALL_SKINNY_SUPPORT
126
+ if (_M * _N <= 65536 ) {
127
+ return blas::Gemm_Launcher<
128
+ container_0_t , container_1_t , container_2_t , 256 , false , false ,
129
+ false , ClSize, Tile<1 , 1 , tileWgSize, tileWgSize>, _t_a, _t_b, s_a,
130
+ s_b, static_cast <int >(gemm_memory_t ::local),
131
+ static_cast <int >(gemm_algorithm_t ::standard),
132
+ static_cast <int >(gemm_vectorization_t ::full), is_beta_zero, 1 ,
133
+ static_cast <int >(gemm_batch_type_t ::strided)>::
134
+ template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda,
135
+ _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
136
+ _stridec, batch_size, _dependencies);
137
+ } else {
138
+ return blas::Gemm_Launcher<
139
+ container_0_t , container_1_t , container_2_t , 256 , false , false ,
140
+ false , ClSize, Tile<4 , 4 , tileWgSize, tileWgSize>, _t_a, _t_b, s_a,
141
+ s_b, static_cast <int >(gemm_memory_t ::local),
142
+ static_cast <int >(gemm_algorithm_t ::standard),
143
+ static_cast <int >(gemm_vectorization_t ::full), is_beta_zero, 2 ,
144
+ static_cast <int >(gemm_batch_type_t ::strided)>::
145
+ template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda,
146
+ _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
147
+ _stridec, batch_size, _dependencies);
148
+ }
149
+ }
145
150
}
146
151
147
152
// Complex Configurations
0 commit comments