misc: improve error handling of sampling kernels (#456)
Add an option to check whether there are NaN inputs.
This PR also removes all `eps` arguments from the renorm kernels: previously
we used a preset `eps` constant to decide when to stop the binary search.
However, this can become inaccurate as the vocabulary size grows (e.g. >=
1e6 in llama3, where our `eps` might be set to 1e-5).
In this PR, we implement a loop variant that does not rely on any
external `eps`, which helps address issues such as
vllm-project/vllm#7137 (comment).
yzh119 authored Aug 20, 2024
1 parent 0d61871 commit 0dce178
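As background on the NaN option: such a check amounts to a block-wide reduction over each row of probabilities before sampling. The sketch below is a minimal illustration of the idea only; the helper name and reporting mechanism are hypothetical, not the exact code added in this PR.

// Hypothetical sketch: returns true if any entry of a row is NaN. Each thread
// scans a strided slice of the row, then the per-thread flags are combined
// with a block-wide OR. A sampling kernel could flag failure when this
// returns true.
__device__ bool RowHasNaN(const float* probs, uint32_t d) {
  bool local_nan = false;
  for (uint32_t i = threadIdx.x; i < d; i += blockDim.x) {
    local_nan |= isnan(probs[i]);
  }
  // __syncthreads_or returns nonzero on every thread if any thread's
  // predicate was nonzero.
  return __syncthreads_or(local_nan) != 0;
}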
Showing 5 changed files with 172 additions and 63 deletions.
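To make the new stopping rule concrete before the diff: the renorm kernels binary-search for a pivot over the probability values, and the old code stopped once high - low <= eps. When probabilities are on the order of 1e-6, an eps of 1e-5 is coarser than the gaps the search needs to resolve. The new loop instead tracks the smallest value strictly above low (min_gt_low) and the largest value at most high (max_le_high), and stops exactly when no probability value separates the two bounds. Below is a host-side, single-threaded analogue of the search now done per row by TopPRenormProbKernel (a simplified sketch under those assumptions, not the kernel itself):

#include <algorithm>
#include <vector>

// Find a pivot such that the probability mass strictly above it is >= top_p,
// without any eps-based termination.
float FindTopPPivot(const std::vector<float>& probs, float top_p) {
  float low = 0.f, high = *std::max_element(probs.begin(), probs.end());
  float min_gt_low, max_le_high;
  do {
    float mid = (low + high) / 2;
    float sum = 0.f;       // f(mid) = sum of probs strictly above mid
    min_gt_low = high;     // smallest prob strictly above low
    max_le_high = low;     // largest prob at most high
    for (float p : probs) {
      if (p > mid) sum += p;
      if (p > low) min_gt_low = std::min(min_gt_low, p);
      if (p <= high) max_le_high = std::max(max_le_high, p);
    }
    if (sum >= top_p) {
      low = mid;                          // invariant: f(low) >= top_p
    } else {
      high = std::min(mid, max_le_high);  // invariant: f(high) < top_p
    }
  } while (min_gt_low != max_le_high);    // no value left between the bounds
  return low;
}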
134 changes: 104 additions & 30 deletions include/flashinfer/sampling.cuh
@@ -16,8 +16,6 @@
#ifndef FLASHINFER_SAMPLING_CUH_
#define FLASHINFER_SAMPLING_CUH_

- #include <driver_types.h>
-
#include <cub/block/block_adjacent_difference.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_scan.cuh>
@@ -347,13 +345,13 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
}
__syncthreads();
if (tx == 0) {
+ output[bx] = sampled_id;
if (temp_storage.data.block_aggregate.pair.count >= k) {
// failed to sample within MAX_TOP_P_ROUNDS
if (success != nullptr) {
success[bx] = false;
}
} else {
- output[bx] = sampled_id;
if (success != nullptr) {
success[bx] = true;
}
@@ -433,13 +431,13 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
}
__syncthreads();
if (tx == 0) {
+ output[bx] = sampled_id;
if (float(q) >= top_p) {
// failed to sample within MAX_TOP_P_ROUNDS
if (success != nullptr) {
success[bx] = false;
}
} else {
- output[bx] = sampled_id;
if (success != nullptr) {
success[bx] = true;
}
@@ -539,13 +537,13 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
}
__syncthreads();
if (tx == 0) {
+ output[bx] = sampled_id;
if (pivot < scaled_p) {
// failed to sample within MAX_ROUNDS
if (success != nullptr) {
success[bx] = false;
}
} else {
- output[bx] = sampled_id;
if (success != nullptr) {
success[bx] = true;
}
@@ -627,13 +625,13 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, DType* uniform_samp
}
__syncthreads();
if (tx == 0) {
+ output[bx] = sampled_id;
if (temp_storage.data.block_aggregate.pair.count >= k || float(q) >= p) {
// failed to sample within MAX_TOP_P_ROUNDS
if (success != nullptr) {
success[bx] = false;
}
} else {
- output[bx] = sampled_id;
if (success != nullptr) {
success[bx] = true;
}
@@ -808,7 +806,7 @@ struct RenormTempStorage {
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
typename DType>
__global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType* top_p_arr,
- float top_p_val, float eps, uint32_t d) {
+ float top_p_val, uint32_t d) {
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
const uint32_t row_idx = bx;
float p = top_p_arr == nullptr ? top_p_val : top_p_arr[bx];
@@ -844,12 +842,20 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType*
threadlocal_max_val = temp_storage.data.max_val;

float low = 0, high = threadlocal_max_val;
+ DType min_gt_low, max_le_high;
DType sum_low(1);
- // f(x) = probs[probs > x], f(x) is non-increasing
- // loop invariant: f(low) >= p, f(high) < p
- while (high - low > eps) {
+ // f(x) = sum(probs[probs > x]), f(x) is non-increasing
+ // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
+ // loop invariant:
+ // - f(low) >= p, f(high) < p
+ // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
+ // stopping condition
+ // - f(low) >= p, f(min_gt_low) == f(max_le_high) == f(high) < p
+ do {
DType threadlocal_sum(0);
float mid = (low + high) / 2;
+ min_gt_low = high;
+ max_le_high = low;
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(DType(0));
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -858,26 +864,42 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType*
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_greater_than_pivot[j] = (probs_vec[j] > mid) ? probs_vec[j] : DType(0);
+ if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+ min_gt_low = min(min_gt_low, probs_vec[j]);
+ }
+ if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+ max_le_high = max(max_le_high, probs_vec[j]);
+ }
}
threadlocal_sum +=
BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Sum<VEC_SIZE>(probs_greater_than_pivot);
__syncthreads();
}
+ min_gt_low = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+ .Reduce(min_gt_low, cub::Min());
+ __syncthreads();
+ max_le_high =
+ BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+ .Reduce(max_le_high, cub::Max());
if (tx == 0) {
temp_storage.data.block_aggregate.value = threadlocal_sum;
+ temp_storage.data.min_val = min_gt_low;
+ temp_storage.data.max_val = max_le_high;
}
__syncthreads();
threadlocal_sum = temp_storage.data.block_aggregate.value;
+ min_gt_low = temp_storage.data.min_val;
+ max_le_high = temp_storage.data.max_val;
if (threadlocal_sum >= p) {
low = mid;
sum_low = float(threadlocal_sum);
} else {
- high = mid;
+ high = min(mid, max_le_high);
}
- }
+ } while (min_gt_low != max_le_high);

- DType normalizer = math::ptx_rcp(max(sum_low, eps));
+ DType normalizer = math::ptx_rcp(max(sum_low, 1e-8));

// normalize
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
@@ -898,7 +920,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType*
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
typename DType, typename IdType>
__global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType* top_k_arr,
- uint32_t top_k_val, float eps, uint32_t d) {
+ uint32_t top_k_val, uint32_t d) {
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
const uint32_t row_idx = bx;
uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
@@ -941,12 +963,20 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
threadlocal_min_val = temp_storage.data.min_val;

float low = threadlocal_min_val - 1, high = threadlocal_max_val;
+ DType min_gt_low, max_le_high;
// f(x) = len(nonzero(probs > x)), f(x) is non-increasing
- // loop invariant: f(low) >= k, f(high) < k
- while (high - low > eps) {
+ // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
+ // loop invariant:
+ // - f(low) >= k, f(high) < k
+ // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
+ // stopping condition: min_gt_low == max_le_high
+ // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
+ do {
int threadlocal_count_sum = 0;
int probs_greater_than_pivot_count[VEC_SIZE]; // pivot initialized to 0
float mid = (low + high) / 2;
+ min_gt_low = high;
+ max_le_high = low;
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
logits_vec.fill(DType(0));
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -956,23 +986,41 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_greater_than_pivot_count[j] =
logits_vec[j] > mid && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
+ if (logits_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+ min_gt_low = min(min_gt_low, logits_vec[j]);
+ }
+ if (logits_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+ max_le_high = max(max_le_high, logits_vec[j]);
+ }
}
threadlocal_count_sum +=
BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce_int)
.Sum<VEC_SIZE>(probs_greater_than_pivot_count);
__syncthreads();
}
+ min_gt_low =
+ BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+ .Reduce(min_gt_low, cub::Min());
+ __syncthreads();
+ max_le_high =
+ BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+ .Reduce(max_le_high, cub::Max());
+ __syncthreads();
if (tx == 0) {
temp_storage.data.block_aggregate.count = threadlocal_count_sum;
+ temp_storage.data.min_val = min_gt_low;
+ temp_storage.data.max_val = max_le_high;
}
__syncthreads();
threadlocal_count_sum = temp_storage.data.block_aggregate.count;
+ min_gt_low = temp_storage.data.min_val;
+ max_le_high = temp_storage.data.max_val;
if (threadlocal_count_sum >= k) {
low = mid;
} else {
- high = mid;
+ high = min(mid, max_le_high);
}
- }
+ } while (min_gt_low != max_le_high);
pivot = low;
}

@@ -996,7 +1044,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
typename DType, typename IdType>
__global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr,
- uint32_t top_k_val, float eps, uint32_t d) {
+ uint32_t top_k_val, uint32_t d) {
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
const uint32_t row_idx = bx;
uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
@@ -1033,13 +1081,21 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
threadlocal_max_val = temp_storage.data.max_val;

float low = 0, high = threadlocal_max_val;
+ DType min_gt_low, max_le_high;
DType sum_low(1);
// f(x) = len(nonzero(probs > x)), f(x) is non-increasing
- // loop invariant: f(low) >= k, f(high) < k
- while (high - low > eps) {
+ // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
+ // loop invariant:
+ // - f(low) >= k, f(high) < k
+ // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
+ // stopping condition: min_gt_low == max_le_high
+ // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
+ do {
Pair<DType> threadlocal_sum{DType(0), 0};
Pair<DType> probs_greater_than_pivot_pair[VEC_SIZE]; // pivot initialized to 0
float mid = (low + high) / 2;
+ min_gt_low = high;
+ max_le_high = low;
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(DType(0));
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -1050,26 +1106,44 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
probs_greater_than_pivot_pair[j] = {
(probs_vec[j] > mid) ? probs_vec[j] : DType(0),
(probs_vec[j] > mid && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
+ if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+ min_gt_low = min(min_gt_low, probs_vec[j]);
+ }
+ if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+ max_le_high = max(max_le_high, probs_vec[j]);
+ }
}
threadlocal_sum += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_pair)
.Sum<VEC_SIZE>(probs_greater_than_pivot_pair);
__syncthreads();
}
+ min_gt_low =
+ BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+ .Reduce(min_gt_low, cub::Min());
+ __syncthreads();
+ max_le_high =
+ BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+ .Reduce(max_le_high, cub::Max());
+ __syncthreads();
if (tx == 0) {
temp_storage.data.block_aggregate.pair = threadlocal_sum;
+ temp_storage.data.min_val = min_gt_low;
+ temp_storage.data.max_val = max_le_high;
}
__syncthreads();
threadlocal_sum = temp_storage.data.block_aggregate.pair;
+ min_gt_low = temp_storage.data.min_val;
+ max_le_high = temp_storage.data.max_val;
if (threadlocal_sum.count >= k) {
low = mid;
sum_low = float(threadlocal_sum.value);
} else {
- high = mid;
+ high = min(mid, max_le_high);
}
- }
+ } while (min_gt_low != max_le_high);

- normalizer = math::ptx_rcp(max(sum_low, eps));
+ normalizer = math::ptx_rcp(max(sum_low, 1e-8));
pivot = low;
}

@@ -1090,7 +1164,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
}

template <typename DType>
- cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr, float eps,
+ cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr,
uint32_t batch_size, float top_p_val, uint32_t d,
cudaStream_t stream = 0) {
const uint32_t BLOCK_THREADS = 1024;
@@ -1099,7 +1173,7 @@ cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr,
const uint32_t smem_size = sizeof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
- void* args[] = {&probs, &renormed_prob, &top_p_arr, &top_p_val, &eps, &d};
+ void* args[] = {&probs, &renormed_prob, &top_p_arr, &top_p_val, &d};
DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
auto kernel = TopPRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType>;
FLASHINFER_CUDA_CALL(
@@ -1110,7 +1184,7 @@ cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr,
}

template <typename DType, typename IdType>
- cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr, float eps,
+ cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr,
uint32_t batch_size, uint32_t top_k_val, uint32_t d,
cudaStream_t stream = 0) {
const uint32_t BLOCK_THREADS = 1024;
@@ -1119,7 +1193,7 @@ cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr
const uint32_t smem_size = sizeof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
- void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &eps, &d};
+ void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &d};
DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
auto kernel = TopKRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
FLASHINFER_CUDA_CALL(
@@ -1130,7 +1204,7 @@ cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr
}

template <typename DType, typename IdType>
- cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits, IdType* top_k_arr, float eps,
+ cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits, IdType* top_k_arr,
uint32_t batch_size, uint32_t top_k_val, uint32_t d,
cudaStream_t stream = 0) {
const uint32_t BLOCK_THREADS = 1024;
@@ -1139,7 +1213,7 @@ cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits, IdType* top_k_ar
const uint32_t smem_size = sizeof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
- void* args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &eps, &d};
+ void* args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &d};
DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
auto kernel = TopKMaskLogitsKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
FLASHINFER_CUDA_CALL(
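With eps gone from the device-side API, a call site for the updated launchers passes one fewer argument. A minimal hypothetical example (assuming the flashinfer namespace; d_probs, d_renormed, batch_size, and d are placeholders set up elsewhere):

// Renormalize each row of d_probs to its top-p mass, writing to d_renormed.
// Per this diff, the launcher no longer takes an eps parameter.
float* d_probs;      // device buffer, batch_size x d (allocated elsewhere)
float* d_renormed;   // device buffer, batch_size x d (allocated elsewhere)
uint32_t batch_size = 16, d = 128000;
cudaError_t status = flashinfer::TopPRenormProb<float>(
    d_probs, d_renormed, /*top_p_arr=*/nullptr, batch_size,
    /*top_p_val=*/0.9f, d);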
6 changes: 3 additions & 3 deletions python/csrc/flashinfer_ops.h
@@ -59,13 +59,13 @@ std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(
std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic);

torch::Tensor top_p_renorm_prob(torch::Tensor probs, std::optional<torch::Tensor> maybe_top_p_arr,
- double top_p_val, double eps);
+ double top_p_val);

torch::Tensor top_k_renorm_prob(torch::Tensor probs, std::optional<torch::Tensor> maybe_top_k_arr,
- unsigned int top_k_val, double eps);
+ unsigned int top_k_val);

torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tensor> maybe_top_k_arr,
- unsigned int top_k_val, double eps);
+ unsigned int top_k_val);

std::vector<torch::Tensor> chain_speculative_sampling(
torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples,
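On the binding side, callers likewise drop the eps argument. A hypothetical C++ call against the declarations above (CUDA float tensors assumed; passing std::nullopt applies the scalar top_k_val to every row of the batch):

// Keep only the top 40 probabilities per row and renormalize them.
torch::Tensor probs =
    torch::softmax(torch::randn({16, 32000}, torch::kCUDA), /*dim=*/-1);
torch::Tensor renormed =
    top_k_renorm_prob(probs, /*maybe_top_k_arr=*/std::nullopt, /*top_k_val=*/40);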