remove ck_moe_fused_stage2's weight_scale expand
Zzz9990 committed Mar 6, 2025
1 parent 24747cc commit 1e3904c
Showing 1 changed file with 10 additions and 4 deletions.
vllm/model_executor/layers/quantization/fp8.py (+10 −4)
```diff
@@ -475,6 +475,7 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                     f"{intermediate_size_per_partition} is not divisible by "
                     f"weight quantization block_k = {block_k}.")
 
+        import pdb; pdb.set_trace()
         # WEIGHTS
         w13_weight = torch.nn.Parameter(torch.empty(
             num_experts,
@@ -716,10 +717,15 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 start += shard_size
 
             if envs.VLLM_USE_AITER_MOE:
-                max_w13_scales = max_w13_scales.unsqueeze(-1).unsqueeze(
-                    -1).expand((-1, layer.w13_weight.shape[1], -1))
-                w2_scales = layer.w2_weight_scale.data.unsqueeze(-1).unsqueeze(
-                    -1).expand((-1, layer.w2_weight.shape[1], -1))
+                if envs.VLLM_USE_AITER_CK_FUSED_MOE:
+                    max_w13_scales = max_w13_scales.unsqueeze(-1)
+                    w2_scales = layer.w2_weight_scale.data.unsqueeze(-1)
+                else:
+                    max_w13_scales = max_w13_scales.unsqueeze(-1).unsqueeze(
+                        -1).expand((-1, layer.w13_weight.shape[1], -1))
+                    w2_scales = layer.w2_weight_scale.data.unsqueeze(-1).unsqueeze(
+                        -1).expand((-1, layer.w2_weight.shape[1], -1))
 
                 layer.w2_weight_scale = torch.nn.Parameter(
                     w2_scales.contiguous(), requires_grad=False)
                 if envs.VLLM_USE_AITER_CK_FUSED_MOE:
```
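For context on what the new branch changes, here is a minimal, standalone sketch of the two scale layouts. The tensor names mirror the diff, but the expert count and row size are made-up illustrative values, and the real AITER kernels are not invoked: the CK fused stage-2 path keeps one scale per expert as shape `(E, 1)`, while the other AITER MoE path expands it across the weight rows to `(E, N, 1)`.

```python
import torch

# Assumed, illustrative sizes (not from the commit): 8 experts,
# 512 rows in each expert's w2 weight matrix.
num_experts, num_rows = 8, 512
w2_weight_scale = torch.rand(num_experts)  # one scale per expert

# CK fused stage-2 path (VLLM_USE_AITER_CK_FUSED_MOE): no expand,
# just a trailing unit dimension -> shape (8, 1).
ck_scales = w2_weight_scale.unsqueeze(-1)

# Other AITER MoE path: broadcast the per-expert scale across the
# weight rows -> shape (8, 512, 1), i.e. (E, w2_weight.shape[1], 1).
expanded_scales = w2_weight_scale.unsqueeze(-1).unsqueeze(-1).expand(
    (-1, num_rows, -1))

print(ck_scales.shape)        # torch.Size([8, 1])
print(expanded_scales.shape)  # torch.Size([8, 512, 1])
```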
