Apply the ColumnParallelLinear transformation for each expert. num_experts is the number of experts, and num_experts_per_token is the number of experts selected for each token. Tensor parallelism partitions each expert's weight along the output-feature (column) dimension.
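The per-expert column-parallel scheme described above can be sketched as follows. This is a minimal single-process simulation, assuming a full weight of shape [num_experts, K, N] split column-wise across tp hypothetical tensor-parallel ranks; all names here are illustrative, not the op's actual API.

```python
import numpy as np

# Hypothetical sketch: each expert's weight W[e] (K x N) is split along its
# output-feature (column) dimension across `tp` tensor-parallel ranks.
num_experts, K, N, tp = 4, 8, 6, 2
rng = np.random.default_rng(0)

X = [rng.standard_normal((3, K)) for _ in range(num_experts)]  # tokens routed to each expert
W = rng.standard_normal((num_experts, K, N))                   # full (unpartitioned) weights

# Each rank holds one column shard of every expert's weight.
shards = np.split(W, tp, axis=2)

# Per-rank partial outputs: Y_rank[e] = X[e] @ W_shard[e].
Y_ranks = [[X[e] @ shard[e] for e in range(num_experts)] for shard in shards]

# An all-gather along the feature axis reproduces the unpartitioned result.
Y_full = [np.concatenate([Y_ranks[r][e] for r in range(tp)], axis=1)
          for e in range(num_experts)]
assert all(np.allclose(Y_full[e], X[e] @ W[e]) for e in range(num_experts))
```

Concatenating the per-rank shards along the last axis is what the gather step on the output performs in a real distributed run.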
Marks whether there is a bias term; exposing this explicitly is convenient for graph optimization.
If True, do an all-gather on the output so that Y is available on all devices; otherwise, every device keeps only its own partition of the output.
Input feature of the linear transformation.
Shape:
Contains the offset of the first token for each expert in X after flattening in dimension *.
expert_offset[i+1] is the prefix sum of the token counts from expert_0 through expert_i.
X_flat = X.reshape(-1, K)
for i in range(num_experts):
    X_expert_i = X_flat[expert_offset[i]: expert_offset[i + 1]]
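To make the indexing scheme concrete, here is a small runnable sketch that builds expert_offset as the exclusive prefix sum of per-expert token counts and slices the flattened input with it. The token counts and sizes are made-up example values.

```python
import numpy as np

# Hypothetical example of how expert_offset indexes the flattened input:
# offsets are the exclusive prefix sum of per-expert token counts.
K = 4
tokens_per_expert = np.array([3, 0, 2, 5])        # expert 1 receives no tokens
num_experts = len(tokens_per_expert)

expert_offset = np.zeros(num_experts + 1, dtype=np.int64)
expert_offset[1:] = np.cumsum(tokens_per_expert)  # [0, 3, 3, 5, 10]

total_tokens = int(expert_offset[-1])
X = np.arange(total_tokens * K, dtype=np.float32)
X_flat = X.reshape(-1, K)                         # [total_tokens, K]

# Each expert's contiguous block of tokens, as in the snippet above.
for i in range(num_experts):
    X_expert_i = X_flat[expert_offset[i]: expert_offset[i + 1]]
    assert len(X_expert_i) == tokens_per_expert[i]
```

Note that expert_offset has num_experts + 1 entries, with expert_offset[0] == 0 and expert_offset[-1] equal to the total number of routed tokens, so the slice for expert i never goes out of bounds.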
Shape:
Transformation weight of all experts.
Shape:
Transformation bias of all experts.
Shape:
Output feature of the linear transformation.
Shape: the last dimension is the full output-feature size when gather_output is True; when gather_output is False, each device holds only its partition, whose last dimension is the full output-feature size divided by the tensor-parallel degree.
Enable int32 accumulation when using the int8 linear kernel.
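The reason for int32 accumulation can be shown with a small numeric sketch, assuming plain int8 operands (the names and sizes here are illustrative): the product of two int8 values can reach 127 * 127 = 16129, so a sum over the reduction dimension K overflows int8 (and even int16 for moderate K) unless the accumulator is widened.

```python
import numpy as np

# Hypothetical illustration of int8 linear with int32 accumulation.
rng = np.random.default_rng(0)
K = 512
a = rng.integers(-128, 128, size=(2, K), dtype=np.int8)
b = rng.integers(-128, 128, size=(K, 3), dtype=np.int8)

# Widen the int8 operands to int32 before the matmul so the
# per-element products are summed in a 32-bit accumulator.
y_i32 = a.astype(np.int32) @ b.astype(np.int32)

# A float64 reference agrees exactly, confirming no overflow occurred.
assert np.array_equal(y_i32, a.astype(np.float64) @ b.astype(np.float64))
```

The worst-case magnitude of one output element is K * 127 * 127, so int32 is safe for any K up to roughly 130,000.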