diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index f5b35efa..fa1cb971 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -16,8 +16,6 @@ #ifndef FLASHINFER_SAMPLING_CUH_ #define FLASHINFER_SAMPLING_CUH_ -#include - #include #include #include @@ -347,13 +345,13 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples, } __syncthreads(); if (tx == 0) { + output[bx] = sampled_id; if (temp_storage.data.block_aggregate.pair.count >= k) { // failed to sample within MAX_TOP_P_ROUNDS if (success != nullptr) { success[bx] = false; } } else { - output[bx] = sampled_id; if (success != nullptr) { success[bx] = true; } @@ -433,13 +431,13 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, } __syncthreads(); if (tx == 0) { + output[bx] = sampled_id; if (float(q) >= top_p) { // failed to sample within MAX_TOP_P_ROUNDS if (success != nullptr) { success[bx] = false; } } else { - output[bx] = sampled_id; if (success != nullptr) { success[bx] = true; } @@ -539,13 +537,13 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples, } __syncthreads(); if (tx == 0) { + output[bx] = sampled_id; if (pivot < scaled_p) { // failed to sample within MAX_ROUNDS if (success != nullptr) { success[bx] = false; } } else { - output[bx] = sampled_id; if (success != nullptr) { success[bx] = true; } @@ -627,13 +625,13 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, DType* uniform_samp } __syncthreads(); if (tx == 0) { + output[bx] = sampled_id; if (temp_storage.data.block_aggregate.pair.count >= k || float(q) >= p) { // failed to sample within MAX_TOP_P_ROUNDS if (success != nullptr) { success[bx] = false; } } else { - output[bx] = sampled_id; if (success != nullptr) { success[bx] = true; } @@ -808,7 +806,7 @@ struct RenormTempStorage { template __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType* top_p_arr, - float top_p_val, float eps, uint32_t d) { + float top_p_val, uint32_t d) { const uint32_t bx = blockIdx.x, tx = threadIdx.x; const uint32_t row_idx = bx; float p = top_p_arr == nullptr ? top_p_val : top_p_arr[bx]; @@ -844,12 +842,20 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType* threadlocal_max_val = temp_storage.data.max_val; float low = 0, high = threadlocal_max_val; + DType min_gt_low, max_le_high; DType sum_low(1); - // f(x) = probs[probs > x], f(x) is non-increasing - // loop invariant: f(low) >= p, f(high) < p - while (high - low > eps) { + // f(x) = sum(probs[probs > x]), f(x) is non-increasing + // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high} + // loop invariant: + // - f(low) >= p, f(high) < p + // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high) + // stopping condition + // - f(low) >= p, f(min_gt_low) == f(max_le_high) == f(high) < p + do { DType threadlocal_sum(0); float mid = (low + high) / 2; + min_gt_low = high; + max_le_high = low; for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(DType(0)); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { @@ -858,26 +864,42 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType* #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { probs_greater_than_pivot[j] = (probs_vec[j] > mid) ? probs_vec[j] : DType(0); + if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { + min_gt_low = min(min_gt_low, probs_vec[j]); + } + if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { + max_le_high = max(max_le_high, probs_vec[j]); + } } threadlocal_sum += BlockReduce(temp_storage.block_prim.reduce) .Sum(probs_greater_than_pivot); __syncthreads(); } + min_gt_low = BlockReduce(temp_storage.block_prim.reduce) + .Reduce(min_gt_low, cub::Min()); + __syncthreads(); + max_le_high = + BlockReduce(temp_storage.block_prim.reduce) + .Reduce(max_le_high, cub::Max()); if (tx == 0) { temp_storage.data.block_aggregate.value = threadlocal_sum; + temp_storage.data.min_val = min_gt_low; + temp_storage.data.max_val = max_le_high; } __syncthreads(); threadlocal_sum = temp_storage.data.block_aggregate.value; + min_gt_low = temp_storage.data.min_val; + max_le_high = temp_storage.data.max_val; if (threadlocal_sum >= p) { low = mid; sum_low = float(threadlocal_sum); } else { - high = mid; + high = min(mid, max_le_high); } - } + } while (min_gt_low != max_le_high); - DType normalizer = math::ptx_rcp(max(sum_low, eps)); + DType normalizer = math::ptx_rcp(max(sum_low, 1e-8)); // normalize for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { @@ -898,7 +920,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType* template __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType* top_k_arr, - uint32_t top_k_val, float eps, uint32_t d) { + uint32_t top_k_val, uint32_t d) { const uint32_t bx = blockIdx.x, tx = threadIdx.x; const uint32_t row_idx = bx; uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx]; @@ -941,12 +963,20 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType threadlocal_min_val = temp_storage.data.min_val; float low = threadlocal_min_val - 1, high = threadlocal_max_val; + DType min_gt_low, max_le_high; // f(x) = len(nonzero(probs > x)), f(x) is non-increasing - // loop invariant: f(low) >= k, f(high) < k - while (high - low > eps) { + // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high} + // loop invariant: + // - f(low) >= k, f(high) < k + // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high) + // stopping condition: min_gt_low == max_le_high + // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k + do { int threadlocal_count_sum = 0; int probs_greater_than_pivot_count[VEC_SIZE]; // pivot initialized to 0 float mid = (low + high) / 2; + min_gt_low = high; + max_le_high = low; for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { logits_vec.fill(DType(0)); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { @@ -956,23 +986,41 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType for (uint32_t j = 0; j < VEC_SIZE; ++j) { probs_greater_than_pivot_count[j] = logits_vec[j] > mid && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d; + if (logits_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { + min_gt_low = min(min_gt_low, logits_vec[j]); + } + if (logits_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { + max_le_high = max(max_le_high, logits_vec[j]); + } } threadlocal_count_sum += BlockReduce(temp_storage.block_prim.reduce_int) .Sum(probs_greater_than_pivot_count); __syncthreads(); } + min_gt_low = + BlockReduce(temp_storage.block_prim.reduce) + .Reduce(min_gt_low, cub::Min()); + __syncthreads(); + max_le_high = + BlockReduce(temp_storage.block_prim.reduce) + .Reduce(max_le_high, cub::Max()); + __syncthreads(); if (tx == 0) { temp_storage.data.block_aggregate.count = threadlocal_count_sum; + temp_storage.data.min_val = min_gt_low; + temp_storage.data.max_val = max_le_high; } __syncthreads(); threadlocal_count_sum = temp_storage.data.block_aggregate.count; + min_gt_low = temp_storage.data.min_val; + max_le_high = temp_storage.data.max_val; if (threadlocal_count_sum >= k) { low = mid; } else { - high = mid; + high = min(mid, max_le_high); } - } + } while (min_gt_low != max_le_high); pivot = low; } @@ -996,7 +1044,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType template __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr, - uint32_t top_k_val, float eps, uint32_t d) { + uint32_t top_k_val, uint32_t d) { const uint32_t bx = blockIdx.x, tx = threadIdx.x; const uint32_t row_idx = bx; uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx]; @@ -1033,13 +1081,21 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* threadlocal_max_val = temp_storage.data.max_val; float low = 0, high = threadlocal_max_val; + DType min_gt_low, max_le_high; DType sum_low(1); // f(x) = len(nonzero(probs > x)), f(x) is non-increasing - // loop invariant: f(low) >= k, f(high) < k - while (high - low > eps) { + // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high} + // loop invariant: + // - f(low) >= k, f(high) < k + // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high) + // stopping condition: min_gt_low == max_le_high + // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k + do { Pair threadlocal_sum{DType(0), 0}; Pair probs_greater_than_pivot_pair[VEC_SIZE]; // pivot initialized to 0 float mid = (low + high) / 2; + min_gt_low = high; + max_le_high = low; for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(DType(0)); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { @@ -1050,26 +1106,44 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* probs_greater_than_pivot_pair[j] = { (probs_vec[j] > mid) ? probs_vec[j] : DType(0), (probs_vec[j] > mid && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { + min_gt_low = min(min_gt_low, probs_vec[j]); + } + if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { + max_le_high = max(max_le_high, probs_vec[j]); + } } threadlocal_sum += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( temp_storage.block_prim.reduce_pair) .Sum(probs_greater_than_pivot_pair); __syncthreads(); } + min_gt_low = + BlockReduce(temp_storage.block_prim.reduce) + .Reduce(min_gt_low, cub::Min()); + __syncthreads(); + max_le_high = + BlockReduce(temp_storage.block_prim.reduce) + .Reduce(max_le_high, cub::Max()); + __syncthreads(); if (tx == 0) { temp_storage.data.block_aggregate.pair = threadlocal_sum; + temp_storage.data.min_val = min_gt_low; + temp_storage.data.max_val = max_le_high; } __syncthreads(); threadlocal_sum = temp_storage.data.block_aggregate.pair; + min_gt_low = temp_storage.data.min_val; + max_le_high = temp_storage.data.max_val; if (threadlocal_sum.count >= k) { low = mid; sum_low = float(threadlocal_sum.value); } else { - high = mid; + high = min(mid, max_le_high); } - } + } while (min_gt_low != max_le_high); - normalizer = math::ptx_rcp(max(sum_low, eps)); + normalizer = math::ptx_rcp(max(sum_low, 1e-8)); pivot = low; } @@ -1090,7 +1164,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* } template -cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr, float eps, +cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr, uint32_t batch_size, float top_p_val, uint32_t d, cudaStream_t stream = 0) { const uint32_t BLOCK_THREADS = 1024; @@ -1099,7 +1173,7 @@ cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr, const uint32_t smem_size = sizeof(RenormTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &renormed_prob, &top_p_arr, &top_p_val, &eps, &d}; + void* args[] = {&probs, &renormed_prob, &top_p_arr, &top_p_val, &d}; DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { auto kernel = TopPRenormProbKernel; FLASHINFER_CUDA_CALL( @@ -1110,7 +1184,7 @@ cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr, } template -cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr, float eps, +cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, cudaStream_t stream = 0) { const uint32_t BLOCK_THREADS = 1024; @@ -1119,7 +1193,7 @@ cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr const uint32_t smem_size = sizeof(RenormTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &eps, &d}; + void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &d}; DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { auto kernel = TopKRenormProbKernel; FLASHINFER_CUDA_CALL( @@ -1130,7 +1204,7 @@ cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr } template -cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits, IdType* top_k_arr, float eps, +cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits, IdType* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, cudaStream_t stream = 0) { const uint32_t BLOCK_THREADS = 1024; @@ -1139,7 +1213,7 @@ cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits, IdType* top_k_ar const uint32_t smem_size = sizeof(RenormTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &eps, &d}; + void* args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &d}; DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { auto kernel = TopKMaskLogitsKernel; FLASHINFER_CUDA_CALL( diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index 3d0b678d..d3ef6193 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -59,13 +59,13 @@ std::vector top_k_top_p_sampling_from_probs( std::optional maybe_top_p_arr, double top_p_val, bool deterministic); torch::Tensor top_p_renorm_prob(torch::Tensor probs, std::optional maybe_top_p_arr, - double top_p_val, double eps); + double top_p_val); torch::Tensor top_k_renorm_prob(torch::Tensor probs, std::optional maybe_top_k_arr, - unsigned int top_k_val, double eps); + unsigned int top_k_val); torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional maybe_top_k_arr, - unsigned int top_k_val, double eps); + unsigned int top_k_val); std::vector chain_speculative_sampling( torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples, diff --git a/python/csrc/sampling.cu b/python/csrc/sampling.cu index dea531e0..745f4abe 100644 --- a/python/csrc/sampling.cu +++ b/python/csrc/sampling.cu @@ -221,7 +221,7 @@ std::vector top_k_top_p_sampling_from_probs( } torch::Tensor top_p_renorm_prob(torch::Tensor probs, std::optional maybe_top_p_arr, - double top_p_val, double eps) { + double top_p_val) { CHECK_INPUT(probs); auto device = probs.device(); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) @@ -244,15 +244,15 @@ torch::Tensor top_p_renorm_prob(torch::Tensor probs, std::optional( static_cast(probs.data_ptr()), static_cast(renorm_probs.data_ptr()), - has_top_p_arr ? static_cast(top_p_arr.data_ptr()) : nullptr, eps, batch_size, - top_p_val, vocab_size, torch_current_stream); + has_top_p_arr ? static_cast(top_p_arr.data_ptr()) : nullptr, batch_size, top_p_val, + vocab_size, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "TopPRenormProb failed with error code " + std::string(cudaGetErrorString(status))); return renorm_probs; } torch::Tensor top_k_renorm_prob(torch::Tensor probs, std::optional maybe_top_k_arr, - unsigned int top_k_val, double eps) { + unsigned int top_k_val) { CHECK_INPUT(probs); auto device = probs.device(); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) @@ -275,7 +275,7 @@ torch::Tensor top_k_renorm_prob(torch::Tensor probs, std::optional( static_cast(probs.data_ptr()), static_cast(renorm_probs.data_ptr()), - has_top_k_arr ? static_cast(top_k_arr.data_ptr()) : nullptr, eps, batch_size, top_k_val, + has_top_k_arr ? static_cast(top_k_arr.data_ptr()) : nullptr, batch_size, top_k_val, vocab_size, torch_current_stream); TORCH_CHECK(status == cudaSuccess, @@ -284,7 +284,7 @@ torch::Tensor top_k_renorm_prob(torch::Tensor probs, std::optional maybe_top_k_arr, - unsigned int top_k_val, double eps) { + unsigned int top_k_val) { CHECK_INPUT(logits); auto device = logits.device(); CHECK_DIM(2, logits); // logits: (batch_size, vocab_size) @@ -307,7 +307,7 @@ torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional( static_cast(logits.data_ptr()), static_cast(mask_logits.data_ptr()), - has_top_k_arr ? static_cast(top_k_arr.data_ptr()) : nullptr, eps, batch_size, top_k_val, + has_top_k_arr ? static_cast(top_k_arr.data_ptr()) : nullptr, batch_size, top_k_val, vocab_size, torch_current_stream); TORCH_CHECK(status == cudaSuccess, diff --git a/python/flashinfer/sampling.py b/python/flashinfer/sampling.py index 290d4461..f83a8104 100644 --- a/python/flashinfer/sampling.py +++ b/python/flashinfer/sampling.py @@ -39,7 +39,10 @@ def _to_tensor_scalar_tuple(x): def sampling_from_probs( - probs: torch.Tensor, uniform_samples: torch.Tensor, deterministic: bool = True + probs: torch.Tensor, + uniform_samples: torch.Tensor, + deterministic: bool = True, + check_nan: bool = False, ) -> torch.Tensor: r"""Fused GPU kernel for category sampling from probabilities. @@ -52,6 +55,8 @@ def sampling_from_probs( Expected to be uniformly distributed in ``[0, 1)``. deterministic: bool Whether to use deterministic kernel implementation, default is ``True``. + check_nan: bool + Whether to check nan in :attr:`probs`, default is ``False``. Returns ------- @@ -82,6 +87,9 @@ def sampling_from_probs( ----- This function expects float32 inputs, and the output is int32. """ + if check_nan: + if torch.any(torch.isnan(probs)): + raise ValueError("Input probs contains NaN.") return _kernels.sampling_from_probs(probs, uniform_samples, deterministic) @@ -90,6 +98,7 @@ def top_p_sampling_from_probs( uniform_samples: torch.Tensor, top_p: Union[torch.Tensor, float], deterministic: bool = True, + check_nan: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities, this operator implements GPU-based rejection sampling without explicit sorting. @@ -111,6 +120,8 @@ def top_p_sampling_from_probs( If a tensor, each request has its own threshold. deterministic: bool Whether to use deterministic kernel implementation, default is ``True``. + check_nan: bool + Whether to check nan in :attr:`probs`, default is ``False``. Returns ------- @@ -150,6 +161,9 @@ def top_p_sampling_from_probs( We encourage users to set ``max_top_p_rounds`` to a reasonable value, e.g., 32. The actual implementation usually use much fewer rounds for rejection sampling because of early stopping. """ + if check_nan: + if torch.any(torch.isnan(probs)): + raise ValueError("Input probs contains NaN.") return _kernels.top_p_sampling_from_probs( probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic ) @@ -160,6 +174,7 @@ def top_k_sampling_from_probs( uniform_samples: torch.Tensor, top_k: Union[torch.Tensor, int], deterministic: bool = True, + check_nan: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Fused GPU kernel for top-k sampling from probabilities, this operator implements GPU-based rejection sampling without explicit sorting. @@ -181,6 +196,8 @@ def top_k_sampling_from_probs( If a tensor, each request has its own threshold. deterministic: bool Whether to use deterministic kernel implementation, default is ``True``. + check_nan: bool + Whether to check nan in :attr:`probs`, default is ``False``. Returns ------- @@ -220,6 +237,9 @@ def top_k_sampling_from_probs( We encourage users to set ``max_top_k_rounds`` to a reasonable value, e.g., 32. The actual implementation usually use much fewer rounds for rejection sampling because of early stopping. """ + if check_nan: + if torch.any(torch.isnan(probs)): + raise ValueError("Input probs contains NaN.") return _kernels.top_k_sampling_from_probs( probs, uniform_samples, *_to_tensor_scalar_tuple(top_k), deterministic ) @@ -230,6 +250,7 @@ def min_p_sampling_from_probs( uniform_samples: torch.Tensor, min_p: Union[torch.Tensor, float], deterministic: bool = True, + check_nan: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Fused GPU kernel for `min_p sampling `_ from probabilities, @@ -252,6 +273,8 @@ def min_p_sampling_from_probs( If a tensor, each request has its own threshold. deterministic: bool Whether to use deterministic kernel implementation, default is ``True``. + check_nan: bool + Whether to check nan in :attr:`probs`, default is ``False``. Returns ------- @@ -292,6 +315,9 @@ def min_p_sampling_from_probs( We encourage users to set ``max_rounds`` to a reasonable value, e.g., 32. The actual implementation usually use much fewer rounds for rejection sampling because of early stopping. """ + if check_nan: + if torch.any(torch.isnan(probs)): + raise ValueError("Input probs contains NaN.") return _kernels.min_p_sampling_from_probs( probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic ) @@ -304,7 +330,7 @@ def top_k_top_p_sampling_from_logits( top_p: Union[torch.Tensor, float], filter_apply_order: str = "top_k_first", deterministic: bool = True, - **kwargs, + check_nan: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Fused GPU kernel for top-k and top-p sampling from pre-softmax logits, @@ -335,6 +361,8 @@ def top_k_top_p_sampling_from_logits( If ``"joint"``, we apply top-k and top-p filter simultaneously in each round. deterministic: bool Whether to use deterministic kernel implementation, default is ``True``. + check_nan: bool + Whether to check nan in :attr:`probs`, default is ``False``. Returns ------- @@ -385,11 +413,16 @@ def top_k_top_p_sampling_from_logits( implementation usually use much fewer rounds for rejection sampling because of early stopping. """ if filter_apply_order == "top_k_first": - masked_logits = top_k_mask_logits(probs, top_k, **kwargs) + masked_logits = top_k_mask_logits(probs, top_k) probs = torch.softmax(masked_logits, dim=-1) - return top_p_sampling_from_probs(probs, uniform_samples, top_p, deterministic) + return top_p_sampling_from_probs( + probs, uniform_samples, top_p, deterministic, check_nan=check_nan + ) elif filter_apply_order == "joint": probs = torch.softmax(probs, dim=-1) + if check_nan: + if torch.any(torch.isnan(probs)): + raise ValueError("Input probs contains NaN.") return _kernels.top_k_top_p_sampling_from_probs( probs, uniform_samples, @@ -408,7 +441,7 @@ def top_k_top_p_sampling_from_probs( top_p: Union[torch.Tensor, float], filter_apply_order: str = "top_k_first", deterministic: bool = True, - **kwargs, + check_nan: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Fused GPU kernel for top-k and top-p sampling from probabilities, @@ -439,6 +472,8 @@ def top_k_top_p_sampling_from_probs( If ``"joint"``, we apply top-k and top-p filter simultaneously in each round. deterministic: bool Whether to use deterministic kernel implementation, default is ``True``. + check_nan: bool + Whether to check nan in :attr:`probs`, default is ``False``. Returns ------- @@ -480,11 +515,14 @@ def top_k_top_p_sampling_from_probs( implementation usually use much fewer rounds for rejection sampling because of early stopping. """ if filter_apply_order == "top_k_first": - renorm_probs = top_k_renorm_prob(probs, top_k, **kwargs) + renorm_probs = top_k_renorm_prob(probs, top_k) return top_p_sampling_from_probs( - renorm_probs, uniform_samples, top_p, deterministic + renorm_probs, uniform_samples, top_p, deterministic, check_nan=check_nan ) elif filter_apply_order == "joint": + if check_nan: + if torch.any(torch.isnan(probs)): + raise ValueError("Input probs contains NaN.") return _kernels.top_k_top_p_sampling_from_probs( probs, uniform_samples, @@ -497,7 +535,8 @@ def top_k_top_p_sampling_from_probs( def top_p_renorm_prob( - probs: torch.Tensor, top_p: Union[torch.Tensor, float], eps: float = 1e-6 + probs: torch.Tensor, + top_p: Union[torch.Tensor, float], ) -> torch.Tensor: r"""Fused GPU kernel for renormalizing probabilities by top-p thresholding. @@ -512,8 +551,6 @@ def top_p_renorm_prob( If a tensor, each request has its own threshold. We mask out the probabilities less than `threshold` where the cumulative sum of ``probs[probs >= threshold]`` is `top_p`, and renormalize the probabilities. - eps: float - The epsilon value for numerical stability. Returns ------- @@ -523,11 +560,12 @@ def top_p_renorm_prob( This combination of ``top_p_renorm_prob`` and ``sampling_from_probs`` should be equivalent to ``top_p_sampling_from_probs``. """ - return _kernels.top_p_renorm_prob(probs, *_to_tensor_scalar_tuple(top_p), eps) + return _kernels.top_p_renorm_prob(probs, *_to_tensor_scalar_tuple(top_p)) def top_k_renorm_prob( - probs: torch.Tensor, top_k: Union[torch.Tensor, int], eps: float = 1e-6 + probs: torch.Tensor, + top_k: Union[torch.Tensor, int], ) -> torch.Tensor: r"""Fused GPU kernel for renormalizing probabilities by top-k thresholding. @@ -541,8 +579,6 @@ def top_k_renorm_prob( If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold. We keep the top-k probabilities, set the rest to zero, and renormalize the probabilities. - eps: float - The epsilon value for numerical stability. Returns ------- @@ -554,11 +590,11 @@ def top_k_renorm_prob( This combination of ``top_k_renorm_prob`` and ``sampling_from_probs`` should be equivalent to ``top_k_sampling_from_probs``. """ - return _kernels.top_k_renorm_prob(probs, *_to_tensor_scalar_tuple(top_k), eps) + return _kernels.top_k_renorm_prob(probs, *_to_tensor_scalar_tuple(top_k)) def top_k_mask_logits( - logits: torch.Tensor, top_k: Union[torch.Tensor, int], eps: float = 1e-5 + logits: torch.Tensor, top_k: Union[torch.Tensor, int] ) -> torch.Tensor: r"""Fused GPU kernel for masking logits by top-k thresholding. @@ -572,8 +608,6 @@ def top_k_mask_logits( If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold. We keep the top-k logits, set the rest to negative infinity. - eps: float - The epsilon value for numerical stability. Returns ------- @@ -584,7 +618,7 @@ def top_k_mask_logits( ---- The combination of ``top_k_mask_logits`` and ``softmax`` should be equivalent to ``top_k_renorm_prob``. """ - return _kernels.top_k_mask_logits(logits, *_to_tensor_scalar_tuple(top_k), eps) + return _kernels.top_k_mask_logits(logits, *_to_tensor_scalar_tuple(top_k)) def chain_speculative_sampling( diff --git a/python/tests/test_sampling.py b/python/tests/test_sampling.py index 64344029..b6738d46 100644 --- a/python/tests/test_sampling.py +++ b/python/tests/test_sampling.py @@ -196,7 +196,9 @@ def test_top_k_top_p_sampling_from_probs_logits_alignment(batch_size, vocab_size p, filter_apply_order="top_k_first", ) - assert torch.all(samples == samples_ref) + assert torch.all( + samples == samples_ref + ), f"{samples} != {samples_ref}, {success}, {success_ref}" assert torch.all(success) assert torch.all(success_ref) @@ -231,7 +233,6 @@ def test_top_k_top_p_joint_sampling_from_logits(batch_size, vocab_size, p): @pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) @pytest.mark.parametrize("p", [0.1, 0.5, 0.9]) def test_top_p_renorm_prob(batch_size, vocab_size, p): - eps = 1e-6 pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) sorted_prob, indices = torch.sort(normalized_prob, descending=False) @@ -244,7 +245,7 @@ def test_top_p_renorm_prob(batch_size, vocab_size, p): dim=-1, keepdim=True ) - renorm_prob = flashinfer.sampling.top_p_renorm_prob(normalized_prob, p, eps=eps) + renorm_prob = flashinfer.sampling.top_p_renorm_prob(normalized_prob, p) numpy.testing.assert_allclose( renorm_prob_ground_truth.cpu().numpy(), renorm_prob.cpu().numpy(), @@ -291,7 +292,7 @@ def test_top_k_mask_logits(batch_size, vocab_size, k): probs = torch.softmax(logits, dim=-1) masked_logits = flashinfer.sampling.top_k_mask_logits(logits, k) renormed_probs = torch.softmax(masked_logits, dim=-1) - renormed_probs_ref = flashinfer.sampling.top_k_renorm_prob(probs, k, 1e-8) + renormed_probs_ref = flashinfer.sampling.top_k_renorm_prob(probs, k) numpy.testing.assert_allclose( renormed_probs.cpu().numpy(),