Apply the RowParallelLinear function for each expert. In the descriptions below, num_experts is the number of experts, and num_experts_per_token is the number of experts selected for each token.
Tensor parallelism is performed along the input feature dimension K: the weight is partitioned along K across the devices in the tensor-parallel group. After each device in the same communication world performs its partial linear transformation, an all-reduce must be performed on the partial outputs to produce the final result.
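A minimal single-process numpy sketch of this row-parallel scheme (the shapes and names here are illustrative assumptions, not the library's API); summing the per-device partial products plays the role of the all-reduce:

import numpy as np

T, K, N, tp_size = 4, 8, 6, 2                  # hypothetical sizes

X = np.random.randn(T, K).astype(np.float32)
W = np.random.randn(K, N).astype(np.float32)

X_shards = np.split(X, tp_size, axis=1)        # each device sees (T, K / tp_size)
W_shards = np.split(W, tp_size, axis=0)        # each device holds (K / tp_size, N)

# Each device computes a partial matmul; summing the partials emulates the
# all-reduce performed across the tensor-parallel communication world.
partials = [x @ w for x, w in zip(X_shards, W_shards)]
Y = sum(partials)

assert np.allclose(Y, X @ W, atol=1e-4)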
num_experts: Number of experts.
Indicates whether a bias term is present; the flag is provided as a convenience for graph optimization.
input_is_parallel: If True, the input is assumed to be already split across the devices and is not split again. If False, the input is split along the K dimension across the devices before the transformation.
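A small sketch of what the flag implies, assuming the split is along the last (K) dimension (rank and tp_size are hypothetical):

import numpy as np

X = np.random.randn(4, 8).astype(np.float32)   # (T, K), full input
tp_size, rank = 2, 0                           # hypothetical tensor-parallel group

# input_is_parallel=False: each rank slices out its own K-shard first.
X_local = np.split(X, tp_size, axis=-1)[rank]  # (T, K / tp_size)
# input_is_parallel=True: X is already the local shard; no split is performed.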
X: Input feature of the linear transformation.
Shape: (*, K) if input_is_parallel is True, where * means any number of leading dimensions and K is the input feature size on each device.
expert_offset: Contains the offset of the first token of each expert in X after flattening the leading dimensions *. Element expert_offset[i+1] is the prefix sum of the token counts from expert_0 to expert_i, with expert_offset[0] = 0, so the tokens of expert_i are the slice:
X_flat = X.reshape(-1, K)
for i in range(num_experts):
    X_expert_i = X_flat[expert_offset[i]: expert_offset[i + 1]]
Shape: (num_experts + 1,).
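For example, expert_offset can be built as a prefix sum over hypothetical per-expert token counts:

import numpy as np

tokens_per_expert = np.array([3, 0, 5, 2])                           # made-up counts
expert_offset = np.concatenate(([0], np.cumsum(tokens_per_expert)))
# expert_offset == [0, 3, 3, 8, 10]; expert i owns X_flat[expert_offset[i]:expert_offset[i+1]]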
Transformation weight.
Shape: (num_experts, K, N), where N is the output feature size; each expert's weight is partitioned along the K dimension across the tensor-parallel devices.
Transformation bias.
Shape: (num_experts, N).
Output feature of the linear transformation.
Shape: (*, N), where * matches the leading dimensions of X.
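Putting the pieces together, a single-process numpy sketch of the per-expert semantics (shapes follow the assumptions above; a real implementation would also shard K across devices and all-reduce):

import numpy as np

num_experts, K, N = 3, 8, 6
tokens_per_expert = np.array([2, 4, 1])
T = tokens_per_expert.sum()

X_flat = np.random.randn(T, K).astype(np.float32)           # flattened input
W = np.random.randn(num_experts, K, N).astype(np.float32)   # per-expert weight
B = np.random.randn(num_experts, N).astype(np.float32)      # per-expert bias
expert_offset = np.concatenate(([0], np.cumsum(tokens_per_expert)))

Y = np.empty((T, N), dtype=np.float32)
for i in range(num_experts):
    s, e = expert_offset[i], expert_offset[i + 1]
    # Apply expert i's linear transformation to its contiguous token slice.
    Y[s:e] = X_flat[s:e] @ W[i] + B[i]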
Whether to enable int32 accumulation when using an int8 linear transformation.
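A sketch of why int32 accumulation matters for an int8 linear: a K-long dot product of int8 values easily overflows narrower accumulators, so the partial sums are widened to int32 (quantization scales omitted):

import numpy as np

K = 1024
x = np.random.randint(-128, 128, size=(4, K), dtype=np.int8)
w = np.random.randint(-128, 128, size=(K, 6), dtype=np.int8)

# Widen to int32 before the matmul: each int8*int8 product can reach 16384,
# and summing ~1024 of them far exceeds the int16 range.
acc = x.astype(np.int32) @ w.astype(np.int32)   # int32 accumulator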