17
17
18
18
#define CUMSUM_BLOCK_SIZE 48 // cumsum开销和并行度之间的tradeoff的结果,勿动
19
19
#define CUMSUM_INVALID_TAG -1 // 用于标记无效的cumsum,尝试过-114514但失败了
20
-
20
+ #ifndef MAX_NUM_EXPERTS
21
+ #define MAX_NUM_EXPERTS 32
22
+ #endif
21
23
// 多阶段算法,控制每block处理的行数来权衡额外开销
22
24
// 首先解析routemap来更新专家当前所收到的token数,然后check前一个block给的前缀和并更新给下一个block
23
25
// 随后,目的行号的信息已获取,立即开始搬运工作,直至任务完全完成
24
- template <typename X_T,
25
- typename routemap_T,
26
- typename probs_T,
27
- int topk,
28
- int num_experts,
29
- bool has_scale>
26
+ template <typename X_T, typename routemap_T, typename probs_T, bool has_scale>
30
27
__global__ void tokens_unzip_stable_kernel (
31
28
const X_T *__restrict__ X,
32
29
const routemap_T *__restrict__ routemap_topk,
@@ -40,11 +37,13 @@ __global__ void tokens_unzip_stable_kernel(
40
37
const int total_zipped_tokens_num,
41
38
const int max_tokens_per_expert,
42
39
const int token_length,
43
- const int scale_length) {
40
+ const int scale_length,
41
+ const int num_experts,
42
+ const int topk) {
44
43
const int block_row_base = blockIdx .x * CUMSUM_BLOCK_SIZE;
45
- int cumsum_offset[num_experts ];
46
- int expert_offset[num_experts ];
47
- int local_cumsum[num_experts ];
44
+ int cumsum_offset[MAX_NUM_EXPERTS ];
45
+ int expert_offset[MAX_NUM_EXPERTS ];
46
+ int local_cumsum[MAX_NUM_EXPERTS ];
48
47
#pragma unroll
49
48
for (int i = 0 ; i < num_experts; i++) {
50
49
cumsum_offset[i] =
@@ -55,13 +54,13 @@ __global__ void tokens_unzip_stable_kernel(
55
54
local_cumsum[i] = 0 ;
56
55
}
57
56
const int base_row_idx = blockIdx .x * CUMSUM_BLOCK_SIZE;
58
- __shared__ int shared_expert_rowmap[CUMSUM_BLOCK_SIZE][num_experts ];
59
- __shared__ probs_T shared_expert_probmap[CUMSUM_BLOCK_SIZE][num_experts ];
57
+ __shared__ int shared_expert_rowmap[CUMSUM_BLOCK_SIZE][MAX_NUM_EXPERTS ];
58
+ __shared__ probs_T shared_expert_probmap[CUMSUM_BLOCK_SIZE][MAX_NUM_EXPERTS ];
60
59
61
60
// --------------------- thread0 单线程任务传递 -------------------------
62
61
if (threadIdx .x == 0 ) [[unlikely]] {
63
- int local_expert_rowmap[CUMSUM_BLOCK_SIZE][num_experts ];
64
- probs_T local_expert_probs[CUMSUM_BLOCK_SIZE][num_experts ];
62
+ int local_expert_rowmap[CUMSUM_BLOCK_SIZE][MAX_NUM_EXPERTS ];
63
+ probs_T local_expert_probs[CUMSUM_BLOCK_SIZE][MAX_NUM_EXPERTS ];
65
64
#pragma unroll
66
65
for (int i = 0 ; i < CUMSUM_BLOCK_SIZE; i++) {
67
66
#pragma unroll
@@ -171,35 +170,28 @@ void dispatch_tokens_unzip_stable(
171
170
#define GET_DATA (tensor, type ) tensor.data<type>()
172
171
173
172
// 分发处理不同的类型组合
174
- #define DISPATCH_CASE (TOKEN_T, PROB_T, INT_T, TOPK, NUM_EXPERTS, HAS_SCALE ) \
175
- auto kernel = tokens_unzip_stable_kernel<TOKEN_T, \
176
- INT_T, \
177
- PROB_T, \
178
- TOPK, \
179
- NUM_EXPERTS, \
180
- HAS_SCALE>; \
181
- kernel<<<grid, block, 0 , X.stream()>>> ( \
182
- GET_DATA (X, TOKEN_T), \
183
- GET_DATA (expert_routemap_topk, INT_T), \
184
- GET_DATA (expert_prob_topk, PROB_T), \
185
- XScale ? XScale->data <float >() : nullptr , \
186
- GET_DATA (X_unzipped, TOKEN_T), \
187
- GET_DATA (zipped_expertwise_rowmap, INT_T), \
188
- GET_DATA (token_prob_unzipped, PROB_T), \
189
- XScale_unzipped.data <float >(), \
190
- global_expertwise_block_cumsum.data <int >(), \
191
- total_zipped_tokens_num, \
192
- max_tokens_per_expert, \
193
- token_length, \
194
- scale_length);
173
+ #define DISPATCH_CASE (TOKEN_T, PROB_T, INT_T, HAS_SCALE ) \
174
+ auto kernel = tokens_unzip_stable_kernel<TOKEN_T, INT_T, PROB_T, HAS_SCALE>; \
175
+ kernel<<<grid, block, 0 , X.stream()>>> ( \
176
+ GET_DATA (X, TOKEN_T), \
177
+ GET_DATA (expert_routemap_topk, INT_T), \
178
+ GET_DATA (expert_prob_topk, PROB_T), \
179
+ XScale ? XScale->data <float >() : nullptr , \
180
+ GET_DATA (X_unzipped, TOKEN_T), \
181
+ GET_DATA (zipped_expertwise_rowmap, INT_T), \
182
+ GET_DATA (token_prob_unzipped, PROB_T), \
183
+ XScale_unzipped.data <float >(), \
184
+ global_expertwise_block_cumsum.data <int >(), \
185
+ total_zipped_tokens_num, \
186
+ max_tokens_per_expert, \
187
+ token_length, \
188
+ scale_length, \
189
+ num_experts, \
190
+ topk);
195
191
196
192
// 可扩展:处理特定的topk和num_experts组合,可根据之后需求进行扩展
197
193
#define HANDLE_EXPERT_CASE (TOKEN_T, PROB_T, INT_T, HAS_SCALE ) \
198
- if (topk == 8 && num_experts == 4 ) { \
199
- DISPATCH_CASE (TOKEN_T, PROB_T, INT_T, 8 , 4 , HAS_SCALE) \
200
- } else { \
201
- std::__throw_invalid_argument; \
202
- }
194
+ DISPATCH_CASE (TOKEN_T, PROB_T, INT_T, HAS_SCALE)
203
195
204
196
#define HANDLE_TOKEN_TYPE (PROB_T, INT_T ) \
205
197
if (DTYPE_CASE (X.dtype (), BFLOAT16)) { \
0 commit comments