diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu
index 16fccae403338..bb3b07e34b399 100644
--- a/csrc/moe/moe_align_sum_kernels.cu
+++ b/csrc/moe/moe_align_sum_kernels.cu
@@ -55,11 +55,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
   __syncthreads();

   // For each expert we accumulate the token counts from the different threads.
-  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
-    tokens_cnts[index(num_experts, 0, eid)] = 0;
+  if (threadIdx.x < num_experts) {
+    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
     for (int i = 1; i <= blockDim.x; ++i) {
-      tokens_cnts[index(num_experts, i, eid)] +=
-          tokens_cnts[index(num_experts, i - 1, eid)];
+      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
+          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
     }
   }

@@ -83,9 +83,10 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
    * For each expert, each thread processes the tokens of the corresponding
    * blocks and stores the corresponding expert_id for each block.
    */
-  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
-    for (int i = cumsum[eid]; i < cumsum[eid + 1]; i += block_size) {
-      expert_ids[i / block_size] = eid;
+  if (threadIdx.x < num_experts) {
+    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
+         i += block_size) {
+      expert_ids[i / block_size] = threadIdx.x;
     }
   }

@@ -140,11 +141,11 @@ __global__ void moe_align_block_size_global_mem_kernel(
   __syncthreads();

   // For each expert we accumulate the token counts from the different threads.
-  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
-    tokens_cnts[index(num_experts, 0, eid)] = 0;
+  if (threadIdx.x < num_experts) {
+    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
     for (int i = 1; i <= blockDim.x; ++i) {
-      tokens_cnts[index(num_experts, i, eid)] +=
-          tokens_cnts[index(num_experts, i - 1, eid)];
+      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
+          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
     }
   }

@@ -168,9 +169,10 @@ __global__ void moe_align_block_size_global_mem_kernel(
    * For each expert, each thread processes the tokens of the corresponding
    * blocks and stores the corresponding expert_id for each block.
    */
-  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
-    for (int i = cumsum[eid]; i < cumsum[eid + 1]; i += block_size) {
-      expert_ids[i / block_size] = eid;
+  if (threadIdx.x < num_experts) {
+    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
+         i += block_size) {
+      expert_ids[i / block_size] = threadIdx.x;
     }
   }

@@ -221,25 +223,61 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                           torch::Tensor experts_ids,
                           torch::Tensor num_tokens_post_pad) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_INTEGRAL_TYPES(
-      topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
-        // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
-        // tensors
-        const int32_t num_thread = WARP_SIZE;
-        const int32_t shared_mem =
-            ((num_thread + 1) * num_experts + (num_experts + 1)) *
-            sizeof(int32_t);
-
-        // set dynamic shared mem
-        auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
-        AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
-            (void*)kernel, shared_mem));
-        kernel<<<1, num_thread, shared_mem, stream>>>(
-            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
-            experts_ids.data_ptr<int32_t>(),
-            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
-            topk_ids.numel());
-      });
+
+  // If we have a very large number of experts, we can no longer use shared
+  // memory.
+  // TODO(simon): the right solution should be calculating the exact right
+  // amount of shared memory and use that. The num_experts >= 96 check is just
+  // a temporary solution to unblock DeepSeek V3.
+  if (num_experts >= 96) {
+    VLLM_DISPATCH_INTEGRAL_TYPES(
+        topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
+          // calc needed amount of global mem for `tokens_cnts` and `cumsum`
+          // tensors
+          const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
+
+          const int32_t mem_tokens_cnts =
+              ((num_experts + 1) * num_experts) * sizeof(int32_t);
+          const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
+          // allocate global memory
+          int32_t* tokens_cnts;
+          int32_t* cumsum;
+          cudaMalloc(&tokens_cnts, mem_tokens_cnts);
+          cudaMalloc(&cumsum, mem_cumsum);
+
+          auto kernel =
+              vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
+          kernel<<<1, num_thread, 0, stream>>>(
+              topk_ids.data_ptr<scalar_t>(),
+              sorted_token_ids.data_ptr<int32_t>(),
+              experts_ids.data_ptr<int32_t>(),
+              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
+              topk_ids.numel(), tokens_cnts, cumsum);
+          cudaFree(tokens_cnts);
+          cudaFree(cumsum);
+        });
+  } else {
+    VLLM_DISPATCH_INTEGRAL_TYPES(
+        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
+          // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
+          // tensors
+          const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
+          const int32_t shared_mem =
+              ((num_thread + 1) * num_experts + (num_experts + 1)) *
+              sizeof(int32_t);
+
+          // set dynamic shared mem
+          auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
+          AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
+              (void*)kernel, shared_mem));
+          kernel<<<1, num_thread, shared_mem, stream>>>(
+              topk_ids.data_ptr<scalar_t>(),
+              sorted_token_ids.data_ptr<int32_t>(),
+              experts_ids.data_ptr<int32_t>(),
+              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
+              topk_ids.numel());
+        });
+  }
 }

 void moe_sum(torch::Tensor& input,  // [num_tokens, topk, hidden_size]
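A note on the num_experts >= 96 fallback above: the shared-memory path keeps a per-thread token-count table plus a cumulative-sum array in dynamic shared memory, sized by the formula in the else branch. The quick evaluation below is illustrative Python, not part of the patch; it assumes a 32-lane warp and 4-byte counters (ROCm wavefronts are 64 lanes, which only changes the small-expert rows since num_thread = max(num_experts, WARP_SIZE)). It shows why an expert count like DeepSeek V3's 256 routed experts cannot fit in per-block shared memory, while the 96-expert threshold keeps the shared-memory kernel under the common 48 KiB default.

# Illustrative arithmetic only (not vLLM code).
WARP_SIZE = 32
INT32_BYTES = 4

def shared_mem_bytes(num_experts: int) -> int:
    # ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t)
    num_thread = max(num_experts, WARP_SIZE)
    return ((num_thread + 1) * num_experts + (num_experts + 1)) * INT32_BYTES

for e in (8, 64, 96, 128, 256):
    print(f"num_experts={e:3d}: {shared_mem_bytes(e) / 1024:7.1f} KiB")
# num_experts=  8:     1.1 KiB
# num_experts= 64:    16.5 KiB
# num_experts= 96:    36.8 KiB   (still under the 48 KiB default limit)
# num_experts=128:    65.0 KiB
# num_experts=256:   258.0 KiB   (beyond even opted-in shared-memory limits)

This is why the large-expert path above allocates tokens_cnts and cumsum with cudaMalloc and passes them into the _global_mem variant of the kernel instead of using dynamic shared memory.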
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 1fb4965110544..835a6044e6f07 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -166,6 +166,8 @@ def create_weights(
         weight_loader = extra_weight_attrs.get("weight_loader")

         if self.block_quant:
+            assert not envs.VLLM_FP8_PADDING, (
+                "FP8 weight padding is not supported in block quantization.")
             tp_size = get_tensor_model_parallel_world_size()
             assert self.quant_config.weight_block_size is not None
             block_n, block_k = (
@@ -196,8 +198,9 @@ def create_weights(
         layer.output_size_per_partition = output_size_per_partition
         layer.orig_dtype = params_dtype

+        fp8_dtype = torch.float8_e4m3fn
         # WEIGHT
-        weight_dtype = (torch.float8_e4m3fn
+        weight_dtype = (fp8_dtype
                         if self.quant_config.is_checkpoint_fp8_serialized else
                         params_dtype)

@@ -252,6 +255,15 @@ def create_weights(
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            if current_platform.is_rocm() and not is_navi():
+                weight, weight_scale, _ = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        weight=layer.weight,
+                        weight_scale=layer.weight_scale_inv,
+                        input_scale=layer.input_scale)
+                layer.weight = Parameter(weight, requires_grad=False)
+                layer.weight_scale_inv = Parameter(weight_scale,
+                                                   requires_grad=False)
             return
         layer.weight = torch.nn.Parameter(layer.weight.data,
                                           requires_grad=False)
@@ -512,6 +524,30 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            if current_platform.is_rocm() and not is_navi():
+                w13_weight, w13_weight_scale_inv, w13_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w13_weight, layer.w13_weight_scale_inv,
+                        layer.w13_input_scale)
+                w2_weight, w2_weight_scale_inv, w2_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w2_weight, layer.w2_weight_scale_inv,
+                        layer.w2_input_scale)
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight,
+                                                      requires_grad=False)
+                layer.w13_weight_scale_inv = torch.nn.Parameter(
+                    w13_weight_scale_inv, requires_grad=False)
+                if w13_input_scale is not None:
+                    layer.w13_input_scale = torch.nn.Parameter(
+                        w13_input_scale, requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(w2_weight,
+                                                     requires_grad=False)
+                layer.w2_weight_scale_inv = torch.nn.Parameter(
+                    w2_weight_scale_inv, requires_grad=False)
+                if w2_input_scale is not None:
+                    layer.w2_input_scale = torch.nn.Parameter(
+                        w2_input_scale, requires_grad=False)
             return
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
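The ROCm branches added above convert block-quantized checkpoints from the OCP float8_e4m3fn layout to the float8_e4m3fnuz layout used by non-Navi AMD GPUs. The helper normalize_e4m3fn_to_e4m3fnuz is only referenced, not shown, in this diff; the sketch below is an illustrative reconstruction of what such a conversion typically does (the function name and details here are assumptions, not vLLM's exact code). Two facts drive it: the e4m3fn bit pattern 0x80 (negative zero) decodes as NaN in e4m3fnuz and must be remapped to zero, and because e4m3fnuz uses a larger exponent bias the same bits represent half the value, so scales are doubled to keep dequantized results unchanged.

from typing import Optional, Tuple

import torch


def normalize_fn_to_fnuz_sketch(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Illustrative e4m3fn -> e4m3fnuz conversion, not vLLM's helper."""
    assert weight.dtype == torch.float8_e4m3fn
    bits = weight.view(torch.int8)
    # 0x80 is -0.0 in e4m3fn but NaN in e4m3fnuz; map it to +0.0.
    bits[bits == -128] = 0
    weight_fnuz = bits.view(torch.float8_e4m3fnuz)
    # The same bits mean half the value under the e4m3fnuz exponent bias,
    # so double the scales to keep weight * scale unchanged after dequant.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight_fnuz, weight_scale, input_scale

The MoE path above applies the same normalization separately to the w13 and w2 weights, their block scales, and any input scales, then re-wraps the results as Parameters.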
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index f3c3e130e4161..b80f9bfbe08fe 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -5,6 +5,9 @@
 import triton
 import triton.language as tl

+from vllm.platforms import current_platform
+from vllm.utils import is_navi
+

 def apply_w8a8_block_fp8_linear(
     input: torch.Tensor,
@@ -34,10 +37,13 @@ def apply_w8a8_block_fp8_linear(

 def input_to_float8(
     x: torch.Tensor,
-    dtype: torch.dtype = torch.float8_e4m3fn
+    dtype: Optional[torch.dtype] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """This function quantizes input values to float8 values "
     "with tensor-wise quantization."""
+    if dtype is None:
+        dtype = (torch.float8_e4m3fnuz if current_platform.is_rocm()
+                 and not is_navi() else torch.float8_e4m3fn)
     finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
@@ -125,7 +131,7 @@ def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = torch.float8_e4m3fn,
+    dtype: Optional[torch.dtype] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
@@ -140,6 +146,9 @@ def per_token_group_quant_fp8(
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
         scaling factor for quantization.
     """
+    if dtype is None:
+        dtype = (torch.float8_e4m3fnuz if current_platform.is_rocm()
+                 and not is_navi() else torch.float8_e4m3fn)
     assert (x.shape[-1] % group_size == 0), (
         f"the last dimension of `x` {x.shape[-1]} must be divisible "
         f"by `group_size` {group_size}")
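For readers unfamiliar with the per_token_group_quant_fp8 contract touched above (the actual implementation in this file is a Triton kernel not shown in the diff), the eager-PyTorch sketch below illustrates the documented semantics under stated assumptions: the last dimension is split into groups of group_size, each group is scaled by its absolute maximum against the FP8 maximum, and the per-group scales are returned alongside the quantized tensor. The name per_token_group_quant_fp8_ref is illustrative, and the hard-coded e4m3fn default stands in for the new platform-dependent selection.

from typing import Tuple

import torch


def per_token_group_quant_fp8_ref(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = torch.float8_e4m3fn,  # e4m3fnuz on non-Navi ROCm
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Eager reference for per-token-group FP8 quantization (illustration
    only; not the Triton kernel in fp8_utils.py)."""
    assert x.shape[-1] % group_size == 0
    finfo = torch.finfo(dtype)
    # View the last dimension as [num_groups, group_size].
    xg = x.reshape(*x.shape[:-1], x.shape[-1] // group_size, group_size)
    # One scale per group: amax / fp8_max, clamped away from zero.
    amax = xg.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32)
    scale = amax / finfo.max
    xq = (xg / scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
    return xq.reshape(x.shape), scale.squeeze(-1)


# Example: quantize a [4, 256] activation with group_size=128 -> scales [4, 2].
x = torch.randn(4, 256)
x_q, scales = per_token_group_quant_fp8_ref(x, group_size=128)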