Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Aiter integration ck fused moe #459

Merged
merged 5 commits into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
VLLM_USE_ROCM_FP8_FLASH_ATTN: bool = False
VLLM_USE_AITER: bool = False
VLLM_USE_AITER_MOE: bool = True
VLLM_USE_AITER_2STAGE_MOE: bool = True
VLLM_USE_AITER_PAGED_ATTN: bool = False
VLLM_USE_AITER_LINEAR: bool = True
VLLM_USE_AITER_NORM: bool = True
Expand Down Expand Up @@ -301,6 +302,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
"VLLM_USE_AITER_MOE":
lambda: (os.getenv("VLLM_USE_AITER_MOE", "True").lower() in ("true", "1")),

# use ater ck fused moe op if ater ops are enabled
"VLLM_USE_AITER_2STAGE_MOE":
lambda: (os.getenv("VLLM_USE_AITER_2STAGE_MOE", "True").lower() in
("true", "1")),

# use ater paged attn op if ater ops are enabled
"VLLM_USE_AITER_PAGED_ATTN":
lambda: (os.getenv("VLLM_USE_AITER_PAGED_ATTN", "False").lower() in
Expand Down
66 changes: 49 additions & 17 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@
PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import aiter_moe_enabled, is_navi
from vllm.utils import aiter_2stage_moe_enabled, aiter_moe_enabled, is_navi

if aiter_moe_enabled():
from aiter.fused_moe_bf16_asm import asm_moe
if aiter_2stage_moe_enabled():
from aiter.fused_moe_bf16_asm import ck_moe_2stages
from aiter.ops.shuffle import shuffle_weight

ACTIVATION_SCHEMES = ["static", "dynamic"]
Expand Down Expand Up @@ -629,12 +631,21 @@ def process_weights_after_loading(self, layer: Module) -> None:
w2_scales.contiguous(), requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(
w13_scales.contiguous(), requires_grad=False)
layer.w13_weight = torch.nn.Parameter(shuffle_weight(
layer.w13_weight),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffle_weight(
layer.w2_weight),
requires_grad=False)

if aiter_2stage_moe_enabled():
layer.w13_weight = torch.nn.Parameter(shuffle_weight(
layer.w13_weight, layout=(32, 32)),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffle_weight(
layer.w2_weight, layout=(32, 32)),
requires_grad=False)
else:
layer.w13_weight = torch.nn.Parameter(shuffle_weight(
layer.w13_weight),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffle_weight(
layer.w2_weight),
requires_grad=False)
return

# If checkpoint is fp8, we need to handle that the
Expand Down Expand Up @@ -705,18 +716,30 @@ def process_weights_after_loading(self, layer: Module) -> None:
start += shard_size

if aiter_moe_enabled():
max_w13_scales = max_w13_scales.unsqueeze(-1).unsqueeze(
-1).expand((-1, layer.w13_weight.shape[1], -1))
w2_scales = layer.w2_weight_scale.data.unsqueeze(-1).unsqueeze(
-1).expand((-1, layer.w2_weight.shape[1], -1))
if aiter_2stage_moe_enabled():
max_w13_scales = max_w13_scales.unsqueeze(-1)
w2_scales = layer.w2_weight_scale.data.unsqueeze(-1)
layer.w13_weight = torch.nn.Parameter(shuffle_weight(
layer.w13_weight, layout=(32, 32)),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffle_weight(
layer.w2_weight, layout=(32, 32)),
requires_grad=False)
else:
max_w13_scales = max_w13_scales.unsqueeze(-1).unsqueeze(
-1).expand((-1, layer.w13_weight.shape[1], -1))
w2_scales = layer.w2_weight_scale.data.unsqueeze(
-1).unsqueeze(-1).expand(
(-1, layer.w2_weight.shape[1], -1))
layer.w13_weight = torch.nn.Parameter(shuffle_weight(
layer.w13_weight),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffle_weight(
layer.w2_weight),
requires_grad=False)

layer.w2_weight_scale = torch.nn.Parameter(
w2_scales.contiguous(), requires_grad=False)
layer.w13_weight = torch.nn.Parameter(shuffle_weight(
layer.w13_weight),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffle_weight(
layer.w2_weight),
requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(
max_w13_scales.contiguous(), requires_grad=False)
return
Expand Down Expand Up @@ -753,6 +776,15 @@ def apply(
)

if aiter_moe_enabled():
if aiter_2stage_moe_enabled():
return ck_moe_2stages(a1=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weight=topk_weights,
topk_ids=topk_ids,
fc1_scale=layer.w13_weight_scale,
fc2_scale=layer.w2_weight_scale)

return asm_moe(
hidden_states=x,
w1=layer.w13_weight,
Expand Down
5 changes: 5 additions & 0 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1759,6 +1759,11 @@ def aiter_moe_enabled() -> bool:
return envs.VLLM_USE_AITER and envs.VLLM_USE_AITER_MOE


@cache
def aiter_2stage_moe_enabled() -> bool:
return envs.VLLM_USE_AITER and envs.VLLM_USE_AITER_2STAGE_MOE


@cache
def aiter_paged_attn_enabled() -> bool:
return envs.VLLM_USE_AITER and envs.VLLM_USE_AITER_PAGED_ATTN
Expand Down
Loading