
Commit

Merge branch 'aiter_integration_final' into aiter_integration_ck_fused_moe

Zzz9990 committed Mar 6, 2025
2 parents cdeb54e + c0dd5ad commit 83e92a6
Showing 13 changed files with 72 additions and 53 deletions.
2 changes: 1 addition & 1 deletion Dockerfile.rocm
@@ -116,7 +116,7 @@ ENV TOKENIZERS_PARALLELISM=false
ENV HIP_FORCE_DEV_KERNARG=1

# Enable Aiter. Make sure this only exists on the aiter branch.
ENV VLLM_USE_AITER=1
# ENV VLLM_USE_AITER=1

CMD ["/bin/bash"]

9 changes: 3 additions & 6 deletions Dockerfile.rocm_base
@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="1a7f4dfa"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="485b4b28"
ARG AITER_BRANCH="41297e56"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"

FROM ${BASE_IMAGE} AS base
@@ -118,17 +118,14 @@ RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
FROM base AS build_aiter
ARG AITER_BRANCH
ARG AITER_REPO
COPY requirements-rocm.txt /app
COPY requirements-common.txt /app
RUN pip install -r requirements-rocm.txt
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
pip install /install/*.whl
RUN git clone --recursive ${AITER_REPO}
RUN cd aiter \
&& git checkout ${AITER_BRANCH} \
&& git submodule update --init --recursive \
&& pip install -r requirements.txt \
&& PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py bdist_wheel --dist-dir=dist && ls /app/aiter/dist/*.whl
&& pip install -r requirements.txt
RUN pip install pyyaml && cd aiter && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py bdist_wheel --dist-dir=dist && ls /app/aiter/dist/*.whl
RUN mkdir -p /app/install && cp /app/aiter/dist/*.whl /app/install

FROM base AS final
2 changes: 1 addition & 1 deletion csrc/rocm/custom_kernels.cu
@@ -1715,7 +1715,7 @@ void wvSpltKQ_(void* in_a, void* in_b, void* out_c, void* scale_a,
dim3 block(64, _WvPrGrp); \
if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
wvSpltKQ_hf_sml_<64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
wvSpltKQ_hf_sml_<64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, af4, bf4, c, s_a, \
s_b, __wvPrGrp, Otp_in, CuCount); \
} else { \
7 changes: 4 additions & 3 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -12,8 +12,9 @@
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder)
from vllm.utils import aiter_paged_attn_enabled

if envs.VLLM_USE_AITER_PAGED_ATTN:
if aiter_paged_attn_enabled():
from vllm.attention.ops.paged_attn_aiter import (PagedAttention,
PagedAttentionMetadata)
else:
@@ -616,7 +617,7 @@ def forward(
else:
assert value is None

if (envs.VLLM_USE_AITER_PAGED_ATTN and kv_cache.dtype.itemsize == 1
if (aiter_paged_attn_enabled() and kv_cache.dtype.itemsize == 1
and not self.aiter_kv_scales_initialized
and kv_cache.shape != torch.Size([0])):
num_blocks = kv_cache.shape[1]
@@ -804,7 +805,7 @@ def forward(
use_custom = _use_rocm_custom_paged_attention(
decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta.max_decode_seq_len)
if envs.VLLM_USE_AITER_PAGED_ATTN:
if aiter_paged_attn_enabled():
out = output[num_prefill_tokens:]
PagedAttention.forward_decode(
decode_query,
6 changes: 3 additions & 3 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -15,9 +15,9 @@
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils import aiter_moe_enabled, direct_register_custom_op

if envs.VLLM_USE_AITER_MOE:
if aiter_moe_enabled():
import aiter

logger = init_logger(__name__)
@@ -950,7 +950,7 @@ def fused_topk(
dtype=torch.int32,
device=hidden_states.device)

if envs.VLLM_USE_AITER_MOE:
if aiter_moe_enabled():
aiter.topk_softmax(topk_weights, topk_ids, token_expert_indicies,
gating_output.float(), renormalize)
else:
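For context, the fused aiter.topk_softmax call above fills the preallocated topk_weights and topk_ids buffers from the router logits. A plain-PyTorch sketch of the unfused computation it replaces (shapes and the renormalize handling are inferred from the surrounding code, not part of this commit):

import torch

def reference_topk_softmax(gating_output: torch.Tensor, topk: int,
                           renormalize: bool) -> tuple[torch.Tensor, torch.Tensor]:
    # gating_output: [num_tokens, num_experts] router logits.
    scores = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        # Rescale the selected expert weights so they sum to 1 per token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
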
8 changes: 4 additions & 4 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -11,16 +11,16 @@
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.envs import VLLM_USE_AITER_MOE
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import aiter_moe_enabled

if VLLM_USE_AITER_MOE:
if aiter_moe_enabled():
from aiter import ck_moe
from aiter.ops.shuffle import shuffle_weight

@@ -101,7 +101,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)

if envs.VLLM_USE_AITER_MOE:
if aiter_moe_enabled():
layer.w13_weight = torch.nn.Parameter(shuffle_weight(
layer.w13_weight.data),
requires_grad=False)
@@ -189,7 +189,7 @@ def forward_cuda(
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)

if VLLM_USE_AITER_MOE:
if aiter_moe_enabled():
return ck_moe(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
8 changes: 4 additions & 4 deletions vllm/model_executor/layers/layernorm.py
@@ -5,10 +5,10 @@
import torch
import torch.nn as nn

from vllm.envs import VLLM_USE_AITER_NORM
from vllm.model_executor.custom_op import CustomOp
from vllm.utils import aiter_norm_enabled

if VLLM_USE_AITER_NORM:
if aiter_norm_enabled():
import aiter


@@ -100,7 +100,7 @@ def forward_cuda(
return out

if residual is not None:
if VLLM_USE_AITER_NORM:
if aiter_norm_enabled():
aiter.rmsnorm2d_fwd_with_add(
x,
x,
@@ -118,7 +118,7 @@
)
return x, residual

if VLLM_USE_AITER_NORM:
if aiter_norm_enabled():
out = aiter.rms_norm(x, self.weight.data, self.variance_epsilon)
else:
out = torch.empty_like(x)
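The AITER path above routes RMSNorm through aiter.rms_norm and aiter.rmsnorm2d_fwd_with_add instead of vLLM's own fused kernels. As a hedged reference for what the non-residual branch is expected to compute (an eager PyTorch sketch, not the fused op):

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor,
                       eps: float) -> torch.Tensor:
    # Root-mean-square normalization over the last dimension; intended to
    # mirror aiter.rms_norm(x, weight, eps) up to numerical precision.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight
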
6 changes: 3 additions & 3 deletions vllm/model_executor/layers/linear.py
@@ -7,7 +7,6 @@
import torch
from torch.nn.parameter import Parameter, UninitializedParameter

from vllm import envs
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
@@ -16,8 +15,9 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.utils import aiter_linear_enabled

if envs.VLLM_USE_AITER_LINEAR:
if aiter_linear_enabled():
from aiter.tuned_gemm import tgemm
else:
from vllm.model_executor.layers.tuned_gemm import tgemm
@@ -256,7 +256,7 @@ def forward(
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
if type(self.quant_method
) is UnquantizedLinearMethod and envs.VLLM_USE_AITER_LINEAR:
) is UnquantizedLinearMethod and aiter_linear_enabled():
output = tgemm.mm(x, self.weight, bias, self.out_dtype)
else:
output = self.quant_method.apply(self, x, bias)
36 changes: 18 additions & 18 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -32,11 +32,11 @@
PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import is_navi
from vllm.utils import aiter_moe_enabled, aiter_2stage_moe_enabled, is_navi

if envs.VLLM_USE_AITER_MOE:
if aiter_moe_enabled():
from aiter.fused_moe_bf16_asm import asm_moe
if envs.VLLM_USE_AITER_2STAGE_MOE:
if aiter_2stage_moe_enabled():
from aiter.fused_moe_bf16_asm import ck_moe_2stages
from aiter.ops.shuffle import shuffle_weight

@@ -621,7 +621,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight,
requires_grad=False)
if envs.VLLM_USE_AITER_MOE:
if aiter_moe_enabled():
w13_scales = layer.w13_weight_scale.data.unsqueeze(
-1).unsqueeze(-1).expand(
(-1, layer.w13_weight.shape[1], -1))
@@ -632,7 +632,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
layer.w13_weight_scale = torch.nn.Parameter(
w13_scales.contiguous(), requires_grad=False)

if envs.VLLM_USE_AITER_2STAGE_MOE:
if aiter_2stage_moe_enabled():
layer.w13_weight = torch.nn.Parameter(
shuffle_weight(layer.w13_weight, layout=(32, 32)),
requires_grad=False)
@@ -715,32 +715,32 @@ def process_weights_after_loading(self, layer: Module) -> None:
dq_weight, max_w13_scales[expert_id])
start += shard_size

if envs.VLLM_USE_AITER_MOE:
if envs.VLLM_USE_AITER_2STAGE_MOE:
if aiter_moe_enabled():
if aiter_2stage_moe_enabled():
max_w13_scales = max_w13_scales.unsqueeze(-1)
w2_scales = layer.w2_weight_scale.data.unsqueeze(-1)
else:
max_w13_scales = max_w13_scales.unsqueeze(-1).unsqueeze(
-1).expand((-1, layer.w13_weight.shape[1], -1))
w2_scales = layer.w2_weight_scale.data.unsqueeze(-1).unsqueeze(
-1).expand((-1, layer.w2_weight.shape[1], -1))

layer.w2_weight_scale = torch.nn.Parameter(
w2_scales.contiguous(), requires_grad=False)
if envs.VLLM_USE_AITER_2STAGE_MOE:
layer.w13_weight = torch.nn.Parameter(
shuffle_weight(layer.w13_weight, layout=(32, 32)),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(
shuffle_weight(layer.w2_weight, layout=(32, 32)),
requires_grad=False)
else:
max_w13_scales = max_w13_scales.unsqueeze(
-1).unsqueeze(-1).expand((
-1, layer.w13_weight.shape[1], -1))
w2_scales = layer.w2_weight_scale.data.unsqueeze(
-1).unsqueeze(-1).expand((
-1, layer.w2_weight.shape[1], -1))
layer.w13_weight = torch.nn.Parameter(shuffle_weight(
layer.w13_weight),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffle_weight(
layer.w2_weight),
requires_grad=False)

layer.w2_weight_scale = torch.nn.Parameter(
w2_scales.contiguous(), requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(
max_w13_scales.contiguous(), requires_grad=False)
return
@@ -776,8 +776,8 @@ def apply(
e_score_correction_bias=e_score_correction_bias,
)

if envs.VLLM_USE_AITER_MOE:
if envs.VLLM_USE_AITER_2STAGE_MOE:
if aiter_moe_enabled():
if aiter_2stage_moe_enabled():
return ck_moe_2stages(a1=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
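The scale reshaping in this file adapts per-expert FP8 scales to the layouts the AITER MoE kernels expect: the two-stage path keeps one scale per expert, while the single-stage asm path broadcasts that scale across the output channels of the weight. A small illustration of the broadcast (the sizes below are arbitrary examples, not values from the commit):

import torch

num_experts, n = 8, 4096                  # E experts, N output channels in w13
max_w13_scales = torch.rand(num_experts)  # one scale per expert

# Two-stage path (ck_moe_2stages): keep a per-expert scale, shape [E, 1].
scales_2stage = max_w13_scales.unsqueeze(-1)

# Single-stage asm path: replicate the scale per output channel, shape [E, N, 1],
# matching unsqueeze(-1).unsqueeze(-1).expand((-1, w13_weight.shape[1], -1)) above.
scales_asm = max_w13_scales.unsqueeze(-1).unsqueeze(-1).expand(-1, n, -1).contiguous()
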
6 changes: 3 additions & 3 deletions vllm/model_executor/layers/tuned_gemm.py
@@ -10,9 +10,9 @@
from vllm import _custom_ops as ops
from vllm import envs
from vllm.platforms import current_platform
from vllm.utils import is_mi250, is_navi
from vllm.utils import aiter_linear_enabled, is_mi250, is_navi

if envs.VLLM_USE_AITER_LINEAR:
if aiter_linear_enabled():
from aiter.tuned_gemm import tgemm as aiter_tgemm

support_tuned_gemms = False
@@ -105,7 +105,7 @@ def scaled_mm(
scale_b: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
if envs.VLLM_USE_AITER_LINEAR:
if aiter_linear_enabled():
return aiter_tgemm.mm(inp,
weight.t(),
otype=out_dtype,
4 changes: 2 additions & 2 deletions vllm/model_executor/models/mixtral.py
@@ -28,7 +28,6 @@
from torch import nn
from transformers import MixtralConfig

from vllm import envs
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
@@ -48,6 +47,7 @@
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import aiter_linear_enabled

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
@@ -85,7 +85,7 @@ def __init__(self,
params_dtype=params_dtype,
quant_config=None,
prefix=f"{prefix}.gate",
out_dtype=torch.float32 if envs.VLLM_USE_AITER_LINEAR else None,
out_dtype=torch.float32 if aiter_linear_enabled() else None,
)

self.experts = FusedMoE(num_experts=num_experts,
6 changes: 1 addition & 5 deletions vllm/platforms/rocm.py
@@ -225,12 +225,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
or envs.VLLM_USE_AITER_MOE
or envs.VLLM_USE_AITER_NORM
or envs.VLLM_USE_AITER_PAGED_ATTN):
logger.info("Aiter main switch - VLLM_USE_AITER is not set,"
logger.info("Aiter main switch (VLLM_USE_AITER) is not set."
" Disabling individual Aiter components")
envs.VLLM_USE_AITER_LINEAR = False
envs.VLLM_USE_AITER_MOE = False
envs.VLLM_USE_AITER_NORM = False
envs.VLLM_USE_AITER_PAGED_ATTN = False

@classmethod
def verify_model_arch(cls, model_arch: str) -> None:
25 changes: 25 additions & 0 deletions vllm/utils.py
@@ -1754,6 +1754,31 @@ def is_navi3() -> bool:
return archName is not None and "gfx11" in archName


@cache
def aiter_moe_enabled() -> bool:
return envs.VLLM_USE_AITER and envs.VLLM_USE_AITER_MOE


@cache
def aiter_2stage_moe_enabled() -> bool:
return envs.VLLM_USE_AITER and envs.VLLM_USE_AITER_2STAGE_MOE


@cache
def aiter_paged_attn_enabled() -> bool:
return envs.VLLM_USE_AITER and envs.VLLM_USE_AITER_PAGED_ATTN


@cache
def aiter_linear_enabled() -> bool:
return envs.VLLM_USE_AITER and envs.VLLM_USE_AITER_LINEAR


@cache
def aiter_norm_enabled() -> bool:
return envs.VLLM_USE_AITER and envs.VLLM_USE_AITER_NORM


def weak_ref_tensors(
tensors: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]:
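All five helpers follow the same pattern: a component path is active only when both the global VLLM_USE_AITER switch and the component-specific flag are set, and the result is cached for the lifetime of the process. A self-contained sketch of that pattern (reading the flags straight from the environment here purely for illustration; vLLM resolves them through its envs module):

import os
from functools import cache

def _flag(name: str) -> bool:
    # Rough stand-in for how the envs module exposes boolean environment variables.
    return os.getenv(name, "0").lower() in ("1", "true")

@cache
def aiter_moe_enabled() -> bool:
    # Enabled only if the main AITER switch AND the MoE component flag are on.
    return _flag("VLLM_USE_AITER") and _flag("VLLM_USE_AITER_MOE")
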
