diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5b13e9d08c7f6..f3fd608894e6d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -37,7 +37,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11" "3.12")
 set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")
 
 # Supported AMD GPU architectures.
-set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101")
+set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1200")
 
 #
 # Supported/expected torch versions for CUDA/ROCm.
@@ -172,6 +172,20 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-result")
 #
 get_torch_gpu_compiler_flags(VLLM_GPU_FLAGS ${VLLM_GPU_LANG})
 
+#
+# Get supported FP8 format based on GPU arches
+#
+get_supported_fp8_format(FP8_FORMAT ${VLLM_GPU_LANG} "${VLLM_GPU_ARCHES}")
+if(${FP8_FORMAT} STREQUAL "E4M3FN")
+  message(STATUS "FP8 format: E4M3FN")
+  list(APPEND VLLM_GPU_FLAGS "-DUSE_CUDA_FP8_FORMAT")
+elseif(${FP8_FORMAT} STREQUAL "E4M3FNUZ")
+  message(STATUS "FP8 format: E4M3FNUZ")
+  list(APPEND VLLM_GPU_FLAGS "-DUSE_HIP_FP8_FORMAT")
+elseif(${FP8_FORMAT} STREQUAL "CONFLICT")
+  message(FATAL_ERROR "Target architectures support different types of FP8 formats!")
+endif()
+
 #
 # Set nvcc parallelism.
 #
diff --git a/cmake/utils.cmake b/cmake/utils.cmake
index 24bb7299338ac..3866ba58a6e11 100644
--- a/cmake/utils.cmake
+++ b/cmake/utils.cmake
@@ -435,3 +435,33 @@ function (define_gpu_extension_target GPU_MOD_NAME)
   install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION}
           COMPONENT ${GPU_MOD_NAME})
 endfunction()
+
+
+# gfx12xx should not be compiled together with gfx94x (MI300) because they support different types of FP8 format.
+# FP8_FORMAT will be returned (E4M3FN / E4M3FNUZ / NONE / CONFLICT)
+macro (get_supported_fp8_format FP8_FORMAT GPU_LANG GPU_ARCHES)
+  set(_USING_CUDA_FP8_FORMAT "FALSE")
+  set(_USING_HIP_FP8_FORMAT "FALSE")
+
+  if (NOT (${GPU_LANG} STREQUAL "HIP"))
+    set(_USING_CUDA_FP8_FORMAT "TRUE")
+  else()
+    foreach (_ARCH ${GPU_ARCHES})
+      if (_ARCH MATCHES "gfx94.")
+        set(_USING_HIP_FP8_FORMAT "TRUE")
+      elseif(_ARCH MATCHES "gfx12..")
+        set(_USING_CUDA_FP8_FORMAT "TRUE")
+      endif()
+    endforeach()
+  endif()
+
+  if ((${_USING_CUDA_FP8_FORMAT} STREQUAL "FALSE") AND (${_USING_HIP_FP8_FORMAT} STREQUAL "FALSE"))
+    set(FP8_FORMAT "NONE")
+  elseif((${_USING_CUDA_FP8_FORMAT} STREQUAL "FALSE") AND (${_USING_HIP_FP8_FORMAT} STREQUAL "TRUE"))
+    set(FP8_FORMAT "E4M3FNUZ")
+  elseif((${_USING_CUDA_FP8_FORMAT} STREQUAL "TRUE") AND (${_USING_HIP_FP8_FORMAT} STREQUAL "FALSE"))
+    set(FP8_FORMAT "E4M3FN")
+  else()
+    set(FP8_FORMAT "CONFLICT")
+  endif()
+endmacro()
diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu
index f2c609c1b68c3..c05ec89f03cc8 100644
--- a/csrc/quantization/fp8/common.cu
+++ b/csrc/quantization/fp8/common.cu
@@ -7,7 +7,7 @@
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
 
-#ifndef USE_ROCM
+#if defined(USE_CUDA_FP8_FORMAT)
   #include
   #include
 #else
@@ -15,7 +15,7 @@
   #include
 #endif
 
-#ifndef USE_ROCM
+#if defined(USE_CUDA_FP8_FORMAT)
 using FP8_TYPE = c10::Float8_e4m3fn;
 C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
     std::numeric_limits<FP8_TYPE>::max();
@@ -50,7 +50,7 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
   }
 
   float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
-#ifndef USE_ROCM
+#if defined(USE_CUDA_FP8_FORMAT)
   return static_cast<FP8_TYPE>(r);
 #else
   // Use hardware cvt instruction for fp8 on rocm
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 12b9d97091274..49ff0332a6981 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -9,7 +9,7 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType
-from vllm.utils import is_hip
+from vllm.utils import is_hip, is_navi4x
 
 logger = init_logger(__name__)
 
@@ -711,7 +711,7 @@ def scaled_fp8_quant(
     assert (input.ndim == 2)
     shape: Union[Tuple[int, int], torch.Size] = input.shape
     # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
-    out_dtype: torch.dtype = torch.float8_e4m3fnuz if is_hip() \
+    out_dtype: torch.dtype = torch.float8_e4m3fnuz if is_hip() and not is_navi4x() \
         else torch.float8_e4m3fn
     if num_token_padding:
         shape = (max(num_token_padding, input.shape[0]), shape[1])
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 633cd5f49fc6a..225b8cdb82c76 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -27,7 +27,7 @@
     PerTensorScaleParameter)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.utils import is_hip, print_warning_once
+from vllm.utils import is_hip, is_navi4x, print_warning_once
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
@@ -227,8 +227,8 @@ def process_weights_after_loading(self, layer: Module) -> None:
             weight = layer.weight
             weight_scale = layer.weight_scale
 
-            # If rocm, use float8_e4m3fnuz.
-            if is_hip():
+            # If rocm (except Navi4x), use float8_e4m3fnuz.
+            if is_hip() and not is_navi4x():
                 weight, weight_scale, input_scale = \
                     normalize_e4m3fn_to_e4m3fnuz(
                         weight=weight,
@@ -378,9 +378,9 @@ def process_weights_after_loading(self, layer: Module) -> None:
 
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            # If rocm, use float8_e4m3fnuz as dtype
+            # If rocm (except Navi4x), use float8_e4m3fnuz as dtype
             fp8_dtype = torch.float8_e4m3fnuz \
-                if is_hip() else torch.float8_e4m3fn
+                if is_hip() and not is_navi4x() else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data,
                                           dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
@@ -427,8 +427,9 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 layer.w13_input_scale.max(), requires_grad=False)
             layer.w2_input_scale = torch.nn.Parameter(
                 layer.w2_input_scale.max(), requires_grad=False)
-            # If rocm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            # If rocm (except Navi4x, which uses e4m3fn),
+            # normalize the weights and scales to e4m3fnuz
+            if is_hip() and not is_navi4x():
                 # Normalize the weights and scales
                 w13_weight, w13_weight_scale, w13_input_scale = \
                     normalize_e4m3fn_to_e4m3fnuz(
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 447ea2b9348e1..4a39b686245c6 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -54,7 +54,7 @@
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
-from vllm.utils import is_hip
+from vllm.utils import is_hip, is_navi4x
 
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
@@ -87,7 +87,8 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.down_proj",
         )
-        self.use_fp8 = isinstance(quant_config, Fp8Config)
+        self.use_fp8 = isinstance(quant_config, Fp8Config) \
+            if is_hip() and not is_navi4x() else False
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -189,8 +190,10 @@ def __init__(
             cache_config=cache_config,
             quant_config=quant_config,
         )
+        # For CUDA devices and Navi4x, attn_fp8_out will be set to false.
         self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \
                             and is_hip() \
+                            and not is_navi4x() \
                             and isinstance(quant_config, Fp8Config)
 
     def forward(
@@ -225,7 +228,8 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
-        self.use_fp8 = isinstance(quant_config, Fp8Config)
+        self.use_fp8 = isinstance(quant_config, Fp8Config) \
+            if is_hip() and not is_navi4x() else False
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         if rope_scaling is not None and getattr(
@@ -456,7 +460,8 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
             if not isinstance(self.layers[layer_idx], nn.Identity):
                 layer_self_attn = self.layers[layer_idx].self_attn
 
-            if is_hip():
+            # Navi4x quantization should be treated as CUDA devices.
+            if is_hip() and not is_navi4x():
                 # The scaling factor convention we are assuming is
                 # quantized_value * scaling_factor ~= true_value
                 # which is consistent with the practice of setting
diff --git a/vllm/utils.py b/vllm/utils.py
index 788e0d424ed52..af857ca315b38 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -8,6 +8,7 @@
 import ipaddress
 import os
 import random
+import re
 import socket
 import subprocess
 import sys
@@ -425,6 +426,17 @@ def is_hip() -> bool:
     return torch.version.hip is not None
 
+
+@lru_cache(maxsize=None)
+def is_navi4x() -> bool:
+    if not is_hip() or not torch.cuda.is_available():
+        return False
+    # All (visible) GPUs must be of the same type,
+    # otherwise FP8 results can't be guaranteed.
+    archName = torch.cuda.get_device_properties('cuda').gcnArchName
+    return (archName is not None) and \
+        (re.match("gfx12[0-9]{2}", archName) is not None)
+
 
 @lru_cache(maxsize=None)
 def is_cpu() -> bool:
     from importlib.metadata import PackageNotFoundError, version