Deepseek V3 support #364

Merged: 11 commits, Jan 16, 2025
104 changes: 71 additions & 33 deletions csrc/moe/moe_align_sum_kernels.cu
@@ -55,11 +55,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
   __syncthreads();

   // For each expert we accumulate the token counts from the different threads.
-  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
-    tokens_cnts[index(num_experts, 0, eid)] = 0;
+  if (threadIdx.x < num_experts) {
+    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
     for (int i = 1; i <= blockDim.x; ++i) {
-      tokens_cnts[index(num_experts, i, eid)] +=
-          tokens_cnts[index(num_experts, i - 1, eid)];
+      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
+          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
     }
   }

@@ -83,9 +83,10 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
    * For each expert, each thread processes the tokens of the corresponding
    * blocks and stores the corresponding expert_id for each block.
    */
-  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
-    for (int i = cumsum[eid]; i < cumsum[eid + 1]; i += block_size) {
-      expert_ids[i / block_size] = eid;
+  if (threadIdx.x < num_experts) {
+    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
+         i += block_size) {
+      expert_ids[i / block_size] = threadIdx.x;
     }
   }
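Both the shared-memory kernel above and the global-memory variant below implement the same alignment logic. As a reading aid, here is a rough CPU-side sketch of what moe_align_block_size produces; the helper name, the sentinel used for padded slots, and the plain-Python style are illustrative assumptions, not part of this diff.

    def moe_align_block_size_ref(topk_ids, num_experts, block_size):
        # Count how many tokens were routed to each expert.
        counts = [0] * num_experts
        for e in topk_ids:
            counts[e] += 1

        # Pad each expert's token count up to a multiple of block_size and
        # build the running offsets (the `cumsum` array in the kernel).
        cumsum = [0] * (num_experts + 1)
        for e in range(num_experts):
            padded = ((counts[e] + block_size - 1) // block_size) * block_size
            cumsum[e + 1] = cumsum[e] + padded
        num_tokens_post_pad = cumsum[-1]

        # Every block of `block_size` slots belongs to exactly one expert.
        expert_ids = [0] * (num_tokens_post_pad // block_size)
        for e in range(num_experts):
            for i in range(cumsum[e], cumsum[e + 1], block_size):
                expert_ids[i // block_size] = e

        # Scatter token indices into their expert's padded region; unused
        # padding slots keep a sentinel value (assumed to be len(topk_ids)).
        sorted_token_ids = [len(topk_ids)] * num_tokens_post_pad
        fill = list(cumsum[:-1])
        for tok, e in enumerate(topk_ids):
            sorted_token_ids[fill[e]] = tok
            fill[e] += 1
        return sorted_token_ids, expert_ids, num_tokens_post_pad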

@@ -140,11 +141,11 @@ __global__ void moe_align_block_size_global_mem_kernel(
   __syncthreads();

   // For each expert we accumulate the token counts from the different threads.
-  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
-    tokens_cnts[index(num_experts, 0, eid)] = 0;
+  if (threadIdx.x < num_experts) {
+    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
     for (int i = 1; i <= blockDim.x; ++i) {
-      tokens_cnts[index(num_experts, i, eid)] +=
-          tokens_cnts[index(num_experts, i - 1, eid)];
+      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
+          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
     }
   }

@@ -168,9 +169,10 @@ __global__ void moe_align_block_size_global_mem_kernel(
    * For each expert, each thread processes the tokens of the corresponding
    * blocks and stores the corresponding expert_id for each block.
    */
-  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
-    for (int i = cumsum[eid]; i < cumsum[eid + 1]; i += block_size) {
-      expert_ids[i / block_size] = eid;
+  if (threadIdx.x < num_experts) {
+    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
+         i += block_size) {
+      expert_ids[i / block_size] = threadIdx.x;
     }
   }

@@ -221,25 +223,61 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                           torch::Tensor experts_ids,
                           torch::Tensor num_tokens_post_pad) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_INTEGRAL_TYPES(
-      topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
-        // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
-        // tensors
-        const int32_t num_thread = WARP_SIZE;
-        const int32_t shared_mem =
-            ((num_thread + 1) * num_experts + (num_experts + 1)) *
-            sizeof(int32_t);
-
-        // set dynamic shared mem
-        auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
-        AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
-            (void*)kernel, shared_mem));
-        kernel<<<1, num_thread, shared_mem, stream>>>(
-            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
-            experts_ids.data_ptr<int32_t>(),
-            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
-            topk_ids.numel());
-      });
+
+  // If we have very large number of experts, we can no longer use shared
+  // memory.
+  // TODO(simon): the right solution should be calculating the exact right
+  // amount of shared memory and use that. The num_experts >= 256 is just a
+  // temporary solution to unblock Deepseek V3.
+  if (num_experts >= 96) {
+    VLLM_DISPATCH_INTEGRAL_TYPES(
+        topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
+          // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
+          // tensors
+          const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
+
+          const int32_t mem_tokens_cnts =
+              ((num_experts + 1) * num_experts) * sizeof(int32_t);
+          const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
+          // allocate global memory
+          int32_t* tokens_cnts;
+          int32_t* cumsum;
+          cudaMalloc(&tokens_cnts, mem_tokens_cnts);
+          cudaMalloc(&cumsum, mem_cumsum);
+
+          auto kernel =
+              vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
+          kernel<<<1, num_thread, 0, stream>>>(
+              topk_ids.data_ptr<scalar_t>(),
+              sorted_token_ids.data_ptr<int32_t>(),
+              experts_ids.data_ptr<int32_t>(),
+              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
+              topk_ids.numel(), tokens_cnts, cumsum);
+          cudaFree(tokens_cnts);
+          cudaFree(cumsum);
+        });
+  } else {
+    VLLM_DISPATCH_INTEGRAL_TYPES(
+        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
+          // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
+          // tensors
+          const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
+          const int32_t shared_mem =
+              ((num_thread + 1) * num_experts + (num_experts + 1)) *
+              sizeof(int32_t);
+
+          // set dynamic shared mem
+          auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
+          AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
+              (void*)kernel, shared_mem));
+          kernel<<<1, num_thread, shared_mem, stream>>>(
+              topk_ids.data_ptr<scalar_t>(),
+              sorted_token_ids.data_ptr<int32_t>(),
+              experts_ids.data_ptr<int32_t>(),
+              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
+              topk_ids.numel());
+        });
+  }
 }

 void moe_sum(torch::Tensor& input,  // [num_tokens, topk, hidden_size]
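Why the num_experts >= 96 branch exists: with num_thread = max(num_experts, WARP_SIZE), the dynamic shared-memory request grows roughly quadratically with the expert count, and DeepSeek V3's 256 experts blow past a typical 48 to 64 KB per-block limit (the upstream TODO still mentions 256, while this fork switches at 96). A quick back-of-the-envelope check in Python, assuming a wavefront/warp size of 64 on the target GPUs and a 64 KB LDS, neither of which is stated in the diff itself:

    WARP_SIZE = 64  # assumed wavefront size for this ROCm build

    def shared_mem_bytes(num_experts):
        num_thread = max(num_experts, WARP_SIZE)
        # (num_thread + 1) rows of per-expert token counts plus the cumsum
        # array, all int32, matching the formula in the launcher above.
        return ((num_thread + 1) * num_experts + (num_experts + 1)) * 4

    for n in (64, 96, 128, 256):
        print(n, shared_mem_bytes(n))
    # 64  ->  16,900 bytes
    # 96  ->  37,636 bytes  (still fits in an assumed 64 KB LDS)
    # 128 ->  66,564 bytes  (already over 64 KB)
    # 256 -> 264,196 bytes  (hence the global-memory kernel)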
38 changes: 37 additions & 1 deletion vllm/model_executor/layers/quantization/fp8.py
@@ -166,6 +166,8 @@ def create_weights(
         weight_loader = extra_weight_attrs.get("weight_loader")

         if self.block_quant:
+            assert not envs.VLLM_FP8_PADDING, (
+                "FP8 weight padding is not supported in block quantization.")
             tp_size = get_tensor_model_parallel_world_size()
             assert self.quant_config.weight_block_size is not None
             block_n, block_k = (
@@ -196,8 +198,9 @@ def create_weights(
         layer.output_size_per_partition = output_size_per_partition
         layer.orig_dtype = params_dtype

+        fp8_dtype = torch.float8_e4m3fn
         # WEIGHT
-        weight_dtype = (torch.float8_e4m3fn
+        weight_dtype = (fp8_dtype
                         if self.quant_config.is_checkpoint_fp8_serialized else
                         params_dtype)

@@ -252,6 +255,15 @@ def create_weights(
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            if current_platform.is_rocm() and not is_navi():
+                weight, weight_scale, _ = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        weight=layer.weight,
+                        weight_scale=layer.weight_scale_inv,
+                        input_scale=layer.input_scale)
+                layer.weight = Parameter(weight, requires_grad=False)
+                layer.weight_scale_inv = Parameter(weight_scale,
+                                                   requires_grad=False)
             return
         layer.weight = torch.nn.Parameter(layer.weight.data,
                                           requires_grad=False)
@@ -512,6 +524,30 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            if current_platform.is_rocm() and not is_navi():
+                w13_weight, w13_weight_scale_inv, w13_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w13_weight, layer.w13_weight_scale_inv,
+                        layer.w13_input_scale)
+                w2_weight, w2_weight_scale_inv, w2_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w2_weight, layer.w2_weight_scale_inv,
+                        layer.w2_input_scale)
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight,
+                                                      requires_grad=False)
+                layer.w13_weight_scale_inv = torch.nn.Parameter(
+                    w13_weight_scale_inv, requires_grad=False)
+                if w13_input_scale is not None:
+                    layer.w13_input_scale = torch.nn.Parameter(
+                        w13_input_scale, requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(w2_weight,
+                                                     requires_grad=False)
+                layer.w2_weight_scale_inv = torch.nn.Parameter(
+                    w2_weight_scale_inv, requires_grad=False)
+                if w2_input_scale is not None:
+                    layer.w2_input_scale = torch.nn.Parameter(
+                        w2_input_scale, requires_grad=False)
             return
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
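On ROCm (excluding Navi), block-quantized checkpoints that ship as torch.float8_e4m3fn are converted to torch.float8_e4m3fnuz before use. normalize_e4m3fn_to_e4m3fnuz is an existing vLLM helper; the sketch below only illustrates the underlying idea, assuming the usual bit-level relationship between the two formats, and is not the exact implementation:

    import torch

    def fn_to_fnuz_sketch(weight, weight_scale, input_scale=None):
        """Rough idea behind normalize_e4m3fn_to_e4m3fnuz (illustrative only)."""
        assert weight.dtype == torch.float8_e4m3fn
        bits = weight.view(torch.int8)
        # 0x80 encodes -0.0 in e4m3fn but NaN in e4m3fnuz, so clear it first.
        bits[bits == -128] = 0
        # The same bit pattern decodes to half the value in e4m3fnuz
        # (exponent bias 8 instead of 7), so doubling the scales keeps
        # scale * weight numerically unchanged after the reinterpretation.
        weight = bits.view(torch.float8_e4m3fnuz)
        weight_scale = weight_scale * 2.0
        if input_scale is not None:
            input_scale = input_scale * 2.0
        return weight, weight_scale, input_scale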
13 changes: 11 additions & 2 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -5,6 +5,9 @@
 import triton
 import triton.language as tl

+from vllm.platforms import current_platform
+from vllm.utils import is_navi
+

 def apply_w8a8_block_fp8_linear(
     input: torch.Tensor,
@@ -34,10 +37,13 @@ def apply_w8a8_block_fp8_linear(

 def input_to_float8(
     x: torch.Tensor,
-    dtype: torch.dtype = torch.float8_e4m3fn
+    dtype: Optional[torch.dtype] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """This function quantizes input values to float8 values "
     "with tensor-wise quantization."""
+    if dtype is None:
+        dtype = (torch.float8_e4m3fnuz if current_platform.is_rocm()
+                 and not is_navi() else torch.float8_e4m3fn)
     finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
@@ -125,7 +131,7 @@ def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = torch.float8_e4m3fn,
+    dtype: Optional[torch.dtype] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
@@ -140,6 +146,9 @@
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
         scaling factor for quantization.
     """
+    if dtype is None:
+        dtype = (torch.float8_e4m3fnuz if current_platform.is_rocm()
+                 and not is_navi() else torch.float8_e4m3fn)
     assert (x.shape[-1] % group_size == 0), (
         f"the last dimension of `x` {x.shape[-1]} must be divisible "
         f"by `group_size` {group_size}")
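With these changes, callers of input_to_float8 and per_token_group_quant_fp8 that do not pass dtype get e4m3fnuz on (non-Navi) ROCm and e4m3fn elsewhere. For orientation, here is a minimal, self-contained sketch of the tensor-wise scheme that input_to_float8 implements; the helper name and the check at the end are illustrative and not part of the diff:

    import torch

    def quantize_tensorwise(x, dtype=torch.float8_e4m3fn):
        finfo = torch.finfo(dtype)
        min_val, max_val = x.aminmax()
        amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
        scale = finfo.max / amax
        x_q = (x.float() * scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
        # Return the inverse scale so dequantization is x_q.float() * scale_inv.
        return x_q, scale.float().reciprocal()

    x = torch.randn(16, 32)
    x_q, scale_inv = quantize_tensorwise(x)
    print((x - x_q.float() * scale_inv).abs().max())  # small fp8 rounding error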