Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support cached cos/sin in rope APIs #585

Merged
merged 4 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion benchmarks/bench_append_paged_kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import dataclasses
from typing import cast

import flashinfer
import torch
from triton.testing import do_bench

import flashinfer


@dataclasses.dataclass(kw_only=True)
class ModelConfig:
Expand Down
92 changes: 92 additions & 0 deletions include/flashinfer/pos_enc.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <cmath>
#include <cstdint>
#include <iostream>
#include <string>

#include "layout.cuh"
Expand Down Expand Up @@ -156,6 +157,55 @@ __device__ __forceinline__ vec_t<float, vec_size> vec_apply_llama_rope_cos_sin_i
return vec;
}

/*!
 * \brief Apply LLaMA-style rotary embedding to q/k, reading the rotation's
 *   cos/sin values from precomputed caches gathered by per-token position id
 *   (instead of computing them from rope_theta on the fly).
 *
 * Thread layout: blockDim = (bdx, bdy). threadIdx.x (tx) owns one
 * vec_size-wide slice of the head dimension (the host launcher guarantees
 * head_dim == bdx * vec_size), and threadIdx.y (ty) selects one token row,
 * so each block covers bdy consecutive rows of the (nnz, H, D) tensors.
 *
 * \tparam interleave selects between the interleaved and non-interleaved
 *   rope helper variants (vec_apply_llama_rope_cos_sin_interleave /
 *   vec_apply_llama_rope_cos_sin defined above).
 * \param cos_cache,sin_cache row-major (max_position, head_dim) float tables;
 *   row `pos_ids[idx]` holds the cos/sin values for token row idx.
 * \param pos_ids one position id per token row; entries [0, nnz) are read.
 * \param nnz number of token rows in q/k.
 * \note q_rope/k_rope may alias q/k (strides are passed separately), so the
 *   kernel supports both in-place and out-of-place use — TODO confirm callers.
 */
template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
          typename IdType>
__global__ void BatchQKApplyRotaryPosIdsCosSinCacheKernel(
    DType* q, DType* k, DType* q_rope, DType* k_rope, float* __restrict__ cos_cache,
    float* __restrict__ sin_cache, IdType* __restrict__ pos_ids, uint32_t nnz,
    uint32_t num_qo_heads, uint32_t num_kv_heads, size_t q_stride_n, size_t q_stride_h,
    size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n, size_t q_rope_stride_h,
    size_t k_rope_stride_n, size_t k_rope_stride_h) {
  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
  const uint32_t bdy = blockDim.y;

  vec_t<float, vec_size> cos, sin;
  // Tail guard: the last block may cover fewer than bdy token rows.
  if (bx * bdy + ty < nnz) {
    const uint32_t idx = bx * bdy + ty;  // global token-row index
    const IdType pos = pos_ids[idx];     // cache row to gather cos/sin from

    // Each thread loads its vec_size-wide slice of the cos/sin row once and
    // reuses it for every q head and every k head of this token.
    cos.load(cos_cache + pos * head_dim + tx * vec_size);
    sin.load(sin_cache + pos * head_dim + tx * vec_size);

    // Keep the per-head loop rolled (presumably to bound code size /
    // register pressure — TODO confirm against profiling).
#pragma unroll 1
    for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) {
      DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
      DType* q_rope_ptr =
          q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
      vec_t<float, vec_size> q_vec;
      if constexpr (interleave) {
        q_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(q_ptr, cos, sin);
      } else {
        q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin);
      }
      // Rotation is done in float; cast back to DType on store.
      q_vec.cast_store(q_rope_ptr + tx * vec_size);
    }

#pragma unroll 1
    for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) {
      DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
      DType* k_rope_ptr =
          k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
      vec_t<float, vec_size> k_vec;
      if constexpr (interleave) {
        k_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(k_ptr, cos, sin);
      } else {
        k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin);
      }
      k_vec.cast_store(k_rope_ptr + tx * vec_size);
    }
  }
}

template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
typename IdType>
__global__ void BatchQKApplyRotaryPosIdsKernel(
Expand Down Expand Up @@ -308,6 +358,48 @@ __global__ void BatchQKApplyRotaryKernel(
__VA_ARGS__ \
}

/*!
 * \brief Host launcher for BatchQKApplyRotaryPosIdsCosSinCacheKernel.
 *
 * Dispatches on `interleave` and `head_dim` at compile time, picks a launch
 * configuration covering all nnz token rows, and launches on `stream`.
 *
 * \param cos_cache,sin_cache row-major (max_position, head_dim) float tables.
 * \param pos_ids one position id per token row; entries [0, nnz) are read.
 * \param stream CUDA stream to launch on (default stream if nullptr).
 * \return cudaSuccess, or the error propagated by FLASHINFER_CUDA_CALL.
 */
template <typename DType, typename IdType>
cudaError_t BatchQKApplyRotaryPosIdsCosSinCache(
    DType* q, DType* k, DType* q_rope, DType* k_rope, float* cos_cache, float* sin_cache,
    IdType* pos_ids, uint32_t nnz, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
    size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
    size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h,
    bool interleave, cudaStream_t stream = nullptr) {
  DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
    DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
      // vec_size: 16-byte vectorized accesses, but at least HEAD_DIM / 32 so
      // one head never needs more than 32 threads. Note 16 / sizeof(DType) is
      // size_t while HEAD_DIM / 32 is uint32_t: pin std::max's template
      // argument explicitly so the call does not fail template deduction on
      // mixed argument types.
      constexpr uint32_t vec_size = std::max<uint32_t>(16 / sizeof(DType), HEAD_DIM / 32);
      constexpr uint32_t bdx = HEAD_DIM / vec_size;  // threads per head slice
      uint32_t num_threads = std::max(128U, bdx);
      uint32_t bdy = num_threads / bdx;    // token rows handled per block
      dim3 nblks((nnz + bdy - 1) / bdy);   // ceil-div so every row is covered
      dim3 nthrs(bdx, bdy);
      auto kernel = BatchQKApplyRotaryPosIdsCosSinCacheKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx,
                                                              DType, IdType>;
      // Argument order must match the kernel parameter list exactly.
      void* args[] = {(void*)&q,
                      (void*)&k,
                      (void*)&q_rope,
                      (void*)&k_rope,
                      (void*)&cos_cache,
                      (void*)&sin_cache,
                      (void*)&pos_ids,
                      (void*)&nnz,
                      (void*)&num_qo_heads,
                      (void*)&num_kv_heads,
                      (void*)&q_stride_n,
                      (void*)&q_stride_h,
                      (void*)&k_stride_n,
                      (void*)&k_stride_h,
                      (void*)&q_rope_stride_n,
                      (void*)&q_rope_stride_h,
                      (void*)&k_rope_stride_n,
                      (void*)&k_rope_stride_h};
      FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
    });
  });

  return cudaSuccess;
}

template <typename DType, typename IdType>
cudaError_t BatchQKApplyRotaryPosIds(DType* q, DType* k, DType* q_rope, DType* k_rope,
IdType* __restrict__ pos_ids, uint32_t nnz,
Expand Down
7 changes: 7 additions & 0 deletions python/csrc/flashinfer_rope_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,17 @@ void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor
float rope_scale, float rope_theta, float low_freq_factor,
float high_freq_factor, float old_context_length);

// Apply RoPE to q/k using user-supplied float32 cos/sin cache tensors of
// shape (max_seq_len, head_dim), gathered by per-token position ids; results
// are written to q_rope/k_rope. Defined in rope.cu.
void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
                                      torch::Tensor k_rope, torch::Tensor cos_cache,
                                      torch::Tensor sin_cache, torch::Tensor pos_ids,
                                      bool interleave);

// Python bindings: expose the RoPE entry points to the torch extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("apply_rope", &apply_rope, "Apply RoPE");
  m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE");
  m.def("apply_rope_pos_ids", &apply_rope_pos_ids, "Apply RoPE with positional ids");
  m.def("apply_llama31_rope_pos_ids", &apply_llama31_rope_pos_ids,
        "Apply Llama 3.1 style RoPE with positional ids");
  m.def("apply_rope_pos_ids_cos_sin_cache", &apply_rope_pos_ids_cos_sin_cache,
        "Apply RoPE with positional ids and cosine/sine cache");
}
62 changes: 58 additions & 4 deletions python/csrc/rope.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ using namespace flashinfer;
void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope,
torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale,
float rope_theta) {
CHECK_CUDA(q); // not necessarily contiguous
CHECK_CUDA(k); // not necessarily contiguous
CHECK_LAST_DIM_CONTIGUOUS(q);
CHECK_LAST_DIM_CONTIGUOUS(k);
CHECK_INPUT(indptr);
CHECK_INPUT(offsets);

Expand Down Expand Up @@ -69,8 +69,8 @@ void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::T
void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
float rope_scale, float rope_theta) {
CHECK_CUDA(q); // not necessarily contiguous
CHECK_CUDA(k); // not necessarily contiguous
CHECK_LAST_DIM_CONTIGUOUS(q);
CHECK_LAST_DIM_CONTIGUOUS(k);
CHECK_INPUT(pos_ids);

auto device = q.device();
Expand Down Expand Up @@ -107,6 +107,60 @@ void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
});
}

/*!
 * \brief Torch binding: apply RoPE to q/k using precomputed float32 cos/sin
 *   caches gathered by per-token position ids; writes into q_rope/k_rope.
 *
 * q/k/q_rope/k_rope: (nnz, H, D), last dim contiguous, any supported fp16
 * dtype. cos_cache/sin_cache: (max_seq_len, D) float32, contiguous.
 * pos_ids: (nnz,) integer tensor (converted to int32 before launch).
 * q_rope/k_rope may alias q/k for in-place application.
 */
void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
                                      torch::Tensor k_rope, torch::Tensor cos_cache,
                                      torch::Tensor sin_cache, torch::Tensor pos_ids,
                                      bool interleave) {
  CHECK_LAST_DIM_CONTIGUOUS(q);
  CHECK_LAST_DIM_CONTIGUOUS(k);
  // Outputs are written through raw pointers and their strides are read
  // below, so they must be validated just like the inputs.
  CHECK_LAST_DIM_CONTIGUOUS(q_rope);
  CHECK_LAST_DIM_CONTIGUOUS(k_rope);
  CHECK_INPUT(cos_cache);
  CHECK_INPUT(sin_cache);
  CHECK_INPUT(pos_ids);
  auto device = q.device();
  CHECK_EQ(k.device(), device);
  CHECK_EQ(q_rope.device(), device);
  CHECK_EQ(k_rope.device(), device);
  CHECK_EQ(cos_cache.device(), device);
  CHECK_EQ(sin_cache.device(), device);
  CHECK_EQ(pos_ids.device(), device);
  CHECK_DIM(3, q);          // q: (nnz, H_Q, D)
  CHECK_DIM(3, k);          // k: (nnz, H_K, D)
  CHECK_DIM(3, q_rope);     // q_rope: (nnz, H_Q, D)
  CHECK_DIM(3, k_rope);     // k_rope: (nnz, H_K, D)
  CHECK_DIM(2, cos_cache);  // cos_cache: (max_seq_len, D)
  CHECK_DIM(2, sin_cache);  // sin_cache: (max_seq_len, D)
  CHECK_DIM(1, pos_ids);    // pos_ids: (nnz,)
  CHECK_EQ(q.size(0), k.size(0));
  CHECK_EQ(q.size(2), k.size(2));
  // Outputs must mirror their inputs' shapes exactly.
  CHECK_EQ(q_rope.size(0), q.size(0));
  CHECK_EQ(q_rope.size(1), q.size(1));
  CHECK_EQ(q_rope.size(2), q.size(2));
  CHECK_EQ(k_rope.size(0), k.size(0));
  CHECK_EQ(k_rope.size(1), k.size(1));
  CHECK_EQ(k_rope.size(2), k.size(2));
  CHECK_EQ(cos_cache.size(0), sin_cache.size(0));
  CHECK_EQ(cos_cache.size(1), q.size(2));
  CHECK_EQ(sin_cache.size(1), q.size(2));
  // The kernel loads the caches through raw float*; reject other dtypes.
  CHECK_EQ(cos_cache.dtype(), torch::kFloat32);
  CHECK_EQ(sin_cache.dtype(), torch::kFloat32);
  unsigned int num_qo_heads = q.size(1);
  unsigned int num_kv_heads = k.size(1);
  unsigned int head_dim = q.size(2);
  unsigned int nnz = q.size(0);
  // The kernel reads pos_ids[0..nnz): one position id per token row.
  CHECK_EQ(pos_ids.size(0), static_cast<int64_t>(nnz));
  size_t q_stride_n = q.stride(0);
  size_t q_stride_h = q.stride(1);
  size_t k_stride_n = k.stride(0);
  size_t k_stride_h = k.stride(1);
  size_t q_rope_stride_n = q_rope.stride(0);
  size_t q_rope_stride_h = q_rope.stride(1);
  size_t k_rope_stride_n = k_rope.stride(0);
  size_t k_rope_stride_h = k_rope.stride(1);
  // Kernel consumes int32 position ids; no-op if pos_ids is already int32.
  pos_ids = pos_ids.to(torch::kInt32);

  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
    cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
        static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
        static_cast<c_type*>(q_rope.data_ptr()), static_cast<c_type*>(k_rope.data_ptr()),
        static_cast<float*>(cos_cache.data_ptr()), static_cast<float*>(sin_cache.data_ptr()),
        static_cast<int32_t*>(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, head_dim,
        q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h,
        k_rope_stride_n, k_rope_stride_h, interleave, torch_current_stream);
    TORCH_CHECK(status == cudaSuccess,
                "BatchQKApplyRotaryPosIdsCosSinCache failed with error code " +
                    std::string(cudaGetErrorString(status)));
    return true;
  });
}

void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets,
bool interleave, float rope_scale, float rope_theta, float low_freq_factor,
Expand Down
8 changes: 8 additions & 0 deletions python/flashinfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,18 @@
from .quantization import segment_packbits as segment_packbits
from .rope import apply_llama31_rope as apply_llama31_rope
from .rope import apply_llama31_rope_inplace as apply_llama31_rope_inplace
from .rope import apply_llama31_rope_pos_ids as apply_llama31_rope_pos_ids
from .rope import (
apply_llama31_rope_pos_ids_inplace as apply_llama31_rope_pos_ids_inplace,
)
from .rope import apply_rope as apply_rope
from .rope import apply_rope_inplace as apply_rope_inplace
from .rope import apply_rope_pos_ids as apply_rope_pos_ids
from .rope import apply_rope_pos_ids_inplace as apply_rope_pos_ids_inplace
from .rope import apply_rope_with_cos_sin_cache as apply_rope_with_cos_sin_cache
from .rope import (
apply_rope_with_cos_sin_cache_inplace as apply_rope_with_cos_sin_cache_inplace,
)
from .sampling import chain_speculative_sampling as chain_speculative_sampling
from .sampling import min_p_sampling_from_probs as min_p_sampling_from_probs
from .sampling import sampling_from_probs as sampling_from_probs
Expand Down
Loading