Skip to content

Commit ae560af

Browse files
authored
Support arbitrary num_experts and topk, with bfloat16 zip prob. (#10583)
* Add arbitrary num_experts and topk support for unzip and zip. * Merge bfloat16 zip prob support for flexible num_experts and topk
1 parent 07d4241 commit ae560af

File tree

3 files changed

+359
-71
lines changed

3 files changed

+359
-71
lines changed

slm/model_zoo/gpt-3/external_ops/token_dispatcher_utils/tokens_stable_unzip.cu

Lines changed: 33 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,13 @@
1717

1818
#define CUMSUM_BLOCK_SIZE 48 // cumsum开销和并行度之间的tradeoff的结果,勿动
1919
#define CUMSUM_INVALID_TAG -1 // 用于标记无效的cumsum,尝试过-114514但失败了
20-
20+
#ifndef MAX_NUM_EXPERTS
21+
#define MAX_NUM_EXPERTS 32
22+
#endif
2123
// 多阶段算法,控制每block处理的行数来权衡额外开销
2224
// 首先解析routemap来更新专家当前所收到的token数,然后check前一个block给的前缀和并更新给下一个block
2325
// 随后,目的行号的信息已获取,立即开始搬运工作,直至任务完全完成
24-
template <typename X_T,
25-
typename routemap_T,
26-
typename probs_T,
27-
int topk,
28-
int num_experts,
29-
bool has_scale>
26+
template <typename X_T, typename routemap_T, typename probs_T, bool has_scale>
3027
__global__ void tokens_unzip_stable_kernel(
3128
const X_T *__restrict__ X,
3229
const routemap_T *__restrict__ routemap_topk,
@@ -40,11 +37,13 @@ __global__ void tokens_unzip_stable_kernel(
4037
const int total_zipped_tokens_num,
4138
const int max_tokens_per_expert,
4239
const int token_length,
43-
const int scale_length) {
40+
const int scale_length,
41+
const int num_experts,
42+
const int topk) {
4443
const int block_row_base = blockIdx.x * CUMSUM_BLOCK_SIZE;
45-
int cumsum_offset[num_experts];
46-
int expert_offset[num_experts];
47-
int local_cumsum[num_experts];
44+
int cumsum_offset[MAX_NUM_EXPERTS];
45+
int expert_offset[MAX_NUM_EXPERTS];
46+
int local_cumsum[MAX_NUM_EXPERTS];
4847
#pragma unroll
4948
for (int i = 0; i < num_experts; i++) {
5049
cumsum_offset[i] =
@@ -55,13 +54,13 @@ __global__ void tokens_unzip_stable_kernel(
5554
local_cumsum[i] = 0;
5655
}
5756
const int base_row_idx = blockIdx.x * CUMSUM_BLOCK_SIZE;
58-
__shared__ int shared_expert_rowmap[CUMSUM_BLOCK_SIZE][num_experts];
59-
__shared__ probs_T shared_expert_probmap[CUMSUM_BLOCK_SIZE][num_experts];
57+
__shared__ int shared_expert_rowmap[CUMSUM_BLOCK_SIZE][MAX_NUM_EXPERTS];
58+
__shared__ probs_T shared_expert_probmap[CUMSUM_BLOCK_SIZE][MAX_NUM_EXPERTS];
6059

6160
// --------------------- thread0 单线程任务传递 -------------------------
6261
if (threadIdx.x == 0) [[unlikely]] {
63-
int local_expert_rowmap[CUMSUM_BLOCK_SIZE][num_experts];
64-
probs_T local_expert_probs[CUMSUM_BLOCK_SIZE][num_experts];
62+
int local_expert_rowmap[CUMSUM_BLOCK_SIZE][MAX_NUM_EXPERTS];
63+
probs_T local_expert_probs[CUMSUM_BLOCK_SIZE][MAX_NUM_EXPERTS];
6564
#pragma unroll
6665
for (int i = 0; i < CUMSUM_BLOCK_SIZE; i++) {
6766
#pragma unroll
@@ -171,35 +170,28 @@ void dispatch_tokens_unzip_stable(
171170
#define GET_DATA(tensor, type) tensor.data<type>()
172171

173172
// 分发处理不同的类型组合
174-
#define DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, TOPK, NUM_EXPERTS, HAS_SCALE) \
175-
auto kernel = tokens_unzip_stable_kernel<TOKEN_T, \
176-
INT_T, \
177-
PROB_T, \
178-
TOPK, \
179-
NUM_EXPERTS, \
180-
HAS_SCALE>; \
181-
kernel<<<grid, block, 0, X.stream()>>>( \
182-
GET_DATA(X, TOKEN_T), \
183-
GET_DATA(expert_routemap_topk, INT_T), \
184-
GET_DATA(expert_prob_topk, PROB_T), \
185-
XScale ? XScale->data<float>() : nullptr, \
186-
GET_DATA(X_unzipped, TOKEN_T), \
187-
GET_DATA(zipped_expertwise_rowmap, INT_T), \
188-
GET_DATA(token_prob_unzipped, PROB_T), \
189-
XScale_unzipped.data<float>(), \
190-
global_expertwise_block_cumsum.data<int>(), \
191-
total_zipped_tokens_num, \
192-
max_tokens_per_expert, \
193-
token_length, \
194-
scale_length);
173+
#define DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE) \
174+
auto kernel = tokens_unzip_stable_kernel<TOKEN_T, INT_T, PROB_T, HAS_SCALE>; \
175+
kernel<<<grid, block, 0, X.stream()>>>( \
176+
GET_DATA(X, TOKEN_T), \
177+
GET_DATA(expert_routemap_topk, INT_T), \
178+
GET_DATA(expert_prob_topk, PROB_T), \
179+
XScale ? XScale->data<float>() : nullptr, \
180+
GET_DATA(X_unzipped, TOKEN_T), \
181+
GET_DATA(zipped_expertwise_rowmap, INT_T), \
182+
GET_DATA(token_prob_unzipped, PROB_T), \
183+
XScale_unzipped.data<float>(), \
184+
global_expertwise_block_cumsum.data<int>(), \
185+
total_zipped_tokens_num, \
186+
max_tokens_per_expert, \
187+
token_length, \
188+
scale_length, \
189+
num_experts, \
190+
topk);
195191

196192
// 可扩展:处理特定的topk和num_experts组合,可根据之后需求进行扩展
197193
#define HANDLE_EXPERT_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE) \
198-
if (topk == 8 && num_experts == 4) { \
199-
DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, 8, 4, HAS_SCALE) \
200-
} else { \
201-
std::__throw_invalid_argument; \
202-
}
194+
DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE)
203195

204196
#define HANDLE_TOKEN_TYPE(PROB_T, INT_T) \
205197
if (DTYPE_CASE(X.dtype(), BFLOAT16)) { \

0 commit comments

Comments (0)