From d7b49c7540063e3c9c1dda9f4cb5031cd7e98184 Mon Sep 17 00:00:00 2001 From: Zihao Ye <expye@outlook.com> Date: Tue, 29 Oct 2024 21:47:36 +0000 Subject: [PATCH] upd --- flashinfer-aot/csrc_aot/flashinfer_ops.cu | 39 ++-- include/flashinfer/pos_enc.cuh | 243 +++++++++++++--------- python/csrc/flashinfer_rope_ops.cu | 30 ++- python/csrc/rope.cu | 132 ++++++------ python/flashinfer/__init__.py | 2 + python/flashinfer/rope.py | 153 +++++++++++++- tests/test_rope.py | 133 ++++++++++-- 7 files changed, 521 insertions(+), 211 deletions(-) diff --git a/flashinfer-aot/csrc_aot/flashinfer_ops.cu b/flashinfer-aot/csrc_aot/flashinfer_ops.cu index 05b259f5..9ab9a86c 100644 --- a/flashinfer-aot/csrc_aot/flashinfer_ops.cu +++ b/flashinfer-aot/csrc_aot/flashinfer_ops.cu @@ -61,10 +61,11 @@ torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional<torch::Tenso torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tensor> maybe_top_k_arr, unsigned int top_k_val); -torch::Tensor chain_speculative_sampling( - torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples, - torch::Tensor target_probs, torch::Tensor output_accepted_token_num, - torch::Tensor output_emitted_token_num, bool deterministic); +torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids, + torch::Tensor uniform_samples, torch::Tensor target_probs, + torch::Tensor output_accepted_token_num, + torch::Tensor output_emitted_token_num, + bool deterministic); void rmsnorm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double eps); @@ -82,24 +83,30 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); -void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, - torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta); - -void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, 
torch::Tensor indptr, - torch::Tensor offsets, bool interleave, float rope_scale, - float rope_theta, float low_freq_factor, float high_freq_factor, - float old_context_length); - -std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, +std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, + torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta); std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k, + torch::Tensor q_rope, torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta, float low_freq_factor, float high_freq_factor, float old_context_length); +std::vector<torch::Tensor> apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, + torch::Tensor q_rope, torch::Tensor k_rope, + torch::Tensor pos_ids, bool interleave, + float rope_scale, float rope_theta); + +std::vector<torch::Tensor> apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, + torch::Tensor q_rope, torch::Tensor k_rope, + torch::Tensor pos_ids, bool interleave, + float rope_scale, float rope_theta, + float low_freq_factor, float high_freq_factor, + float old_context_length); + torch::Tensor packbits(torch::Tensor x, const std::string& bitorder); torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, @@ -141,11 +148,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul"); m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul"); m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU and Mul"); - m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place"); - m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace, - "Apply Llama 3.1 style RoPE in-place"); m.def("apply_rope", &apply_rope, "Apply RoPE"); m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 
style RoPE"); + m.def("apply_rope_pos_ids", &apply_rope_pos_ids, "Apply RoPE with positional ids"); + m.def("apply_llama31_rope_pos_ids", &apply_llama31_rope_pos_ids, + "Apply Llama 3.1 style RoPE with positional ids"); m.def("packbits", &packbits, "GPU packbits operator"); m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator"); m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM operator"); diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index d6f96e4c..ed0b732a 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -17,6 +17,7 @@ #define FLASHINFER_POS_ENC_CUH_ #include <cmath> +#include <cstdint> #include <string> #include "layout.cuh" @@ -94,6 +95,25 @@ __device__ __forceinline__ vec_t<float, vec_size> vec_apply_llama_rope( return vec; } +template <uint32_t vec_size, uint32_t bdx, typename T> +__device__ __forceinline__ vec_t<float, vec_size> vec_apply_llama_rope_cos_sin( + const T* x, const vec_t<float, vec_size>& cos, const vec_t<float, vec_size>& sin) { + constexpr uint32_t head_dim = vec_size * bdx; + vec_t<float, vec_size> permuted_vec, vec; + vec.cast_load(x + threadIdx.x * vec_size); + permuted_vec.cast_load(x + ((threadIdx.x * vec_size < head_dim / 2) + ? threadIdx.x * vec_size + head_dim / 2 + : threadIdx.x * vec_size - head_dim / 2)); + +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + vec[i] = + vec[i] * cos[i] + + ((threadIdx.x * vec_size < head_dim / 2) ? -permuted_vec[i] : permuted_vec[i]) * sin[i]; + } + return vec; +} + /*! * \brief Apply RoPE (Rotary Positional Embeddings) to x[0: head_dim] with interleave, * return thread-local vector. 
@@ -122,13 +142,28 @@ __device__ __forceinline__ vec_t<float, vec_size> vec_apply_llama_rope_interleav return vec; } +template <uint32_t vec_size, uint32_t bdx, typename T> +__device__ __forceinline__ vec_t<float, vec_size> vec_apply_llama_rope_cos_sin_interleave( + const T* x, const vec_t<float, vec_size>& cos, const vec_t<float, vec_size>& sin) { + vec_t<float, vec_size> vec, vec_before; + vec.cast_load(x + threadIdx.x * vec_size); + vec_before = vec; + +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + vec[i] = vec[i] * cos[i] + ((i % 2 == 0) ? -vec_before[i ^ 1] : vec_before[i ^ 1]) * sin[i]; + } + return vec; +} + template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType, typename IdType> -__global__ void BatchQKApplyRotaryInPlaceKernel( - DType* __restrict__ q, DType* __restrict__ k, IdType* __restrict__ indptr, - IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, - size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, float smooth_a, - float smooth_b, float rope_rcp_scale, float rope_rcp_theta) { +__global__ void BatchQKApplyRotaryPosIdsKernel( + DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ pos_ids, uint32_t nnz, + uint32_t num_qo_heads, uint32_t num_kv_heads, size_t q_stride_n, size_t q_stride_h, + size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n, size_t q_rope_stride_h, + size_t k_rope_stride_n, size_t k_rope_stride_h, float smooth_a, float smooth_b, + float rope_rcp_scale, float rope_rcp_theta) { uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; const uint32_t bdy = blockDim.y; vec_t<float, vec_size> freq; @@ -146,61 +181,56 @@ __global__ void BatchQKApplyRotaryInPlaceKernel( freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i]; } - if (bx < batch_size * num_qo_heads) { - // apply rotary to q - const uint32_t batch_idx = bx / num_qo_heads; - const uint32_t qo_head_idx = bx % 
num_qo_heads; - const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx]; - const uint32_t offset = offsets[batch_idx]; -#pragma unroll 2 - for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) { + vec_t<float, vec_size> cos, sin; + + if (bx * bdy + ty < nnz) { + const uint32_t idx = bx * bdy + ty; + const IdType pos = pos_ids[idx]; + +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + float embed = float(pos) * freq[i]; + __sincosf(embed, &sin[i], &cos[i]); + } + +#pragma unroll 1 + for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { + DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h); + DType* q_rope_ptr = + q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h); vec_t<float, vec_size> q_vec; - if (i * bdy + ty < seq_len) { - DType* q_ptr = q + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0, - q_stride_n, q_stride_h); - if constexpr (interleave) { - q_vec = - vec_apply_llama_rope_interleave<vec_size, bdx>(q_ptr, freq, offset + i * bdy + ty); - } else { - q_vec = vec_apply_llama_rope<vec_size, bdx>(q_ptr, freq, offset + i * bdy + ty); - } - q_vec.cast_store(q_ptr + tx * vec_size); + if constexpr (interleave) { + q_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(q_ptr, cos, sin); + } else { + q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin); } + q_vec.cast_store(q_rope_ptr + tx * vec_size); } - } else { - // apply rotary to k - uint32_t batch_idx = (bx - batch_size * num_qo_heads) / num_kv_heads; - uint32_t kv_head_idx = (bx - batch_size * num_qo_heads) % num_kv_heads; - const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx]; - const uint32_t offset = offsets[batch_idx]; -#pragma unroll 2 - for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) { + +#pragma unroll 1 + for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) { + DType* k_ptr = k + 
get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h); + DType* k_rope_ptr = + k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h); vec_t<float, vec_size> k_vec; - if (i * bdy + ty < seq_len) { - DType* k_ptr = k + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0, - k_stride_n, k_stride_h); - if constexpr (interleave) { - k_vec = - vec_apply_llama_rope_interleave<vec_size, bdx>(k_ptr, freq, offset + i * bdy + ty); - } else { - k_vec = vec_apply_llama_rope<vec_size, bdx>(k_ptr, freq, offset + i * bdy + ty); - } - k_vec.cast_store(k_ptr + tx * vec_size); + if constexpr (interleave) { + k_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(k_ptr, cos, sin); + } else { + k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin); } + k_vec.cast_store(k_rope_ptr + tx * vec_size); } } } template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType, typename IdType> -__global__ void BatchQKApplyRotaryKernel(DType* __restrict__ q, DType* __restrict__ k, - DType* __restrict__ q_rope, DType* __restrict__ k_rope, - IdType* __restrict__ indptr, IdType* __restrict__ offsets, - uint32_t batch_size, uint32_t num_qo_heads, - uint32_t num_kv_heads, size_t q_stride_n, - size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, - float smooth_a, float smooth_b, float rope_rcp_scale, - float rope_rcp_theta) { +__global__ void BatchQKApplyRotaryKernel( + DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ indptr, + IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, + size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, + float smooth_a, float smooth_b, float rope_rcp_scale, float rope_rcp_theta) { uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; const uint32_t bdy = 
blockDim.y; vec_t<float, vec_size> freq; @@ -232,8 +262,7 @@ __global__ void BatchQKApplyRotaryKernel(DType* __restrict__ q, DType* __restric q_stride_n, q_stride_h); DType* q_rope_ptr = q_rope + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0, - /*q_stride_n=*/num_qo_heads * head_dim, - /*q_stride_h=*/head_dim); + q_rope_stride_n, q_rope_stride_h); if constexpr (interleave) { q_vec = vec_apply_llama_rope_interleave<vec_size, bdx>(q_ptr, freq, offset + i * bdy + ty); @@ -257,8 +286,7 @@ __global__ void BatchQKApplyRotaryKernel(DType* __restrict__ q, DType* __restric k_stride_n, k_stride_h); DType* k_rope_ptr = k_rope + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0, - /*kv_stride_n=*/num_kv_heads * head_dim, - /*kv_stride_h=*/head_dim); + k_rope_stride_n, k_rope_stride_h); if constexpr (interleave) { k_vec = vec_apply_llama_rope_interleave<vec_size, bdx>(k_ptr, freq, offset + i * bdy + ty); @@ -281,13 +309,14 @@ __global__ void BatchQKApplyRotaryKernel(DType* __restrict__ q, DType* __restric } template <typename DType, typename IdType> -cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ k, - IdType* __restrict__ indptr, IdType* __restrict__ offsets, - uint32_t batch_size, uint32_t num_qo_heads, - uint32_t num_kv_heads, uint32_t head_dim, size_t q_stride_n, - size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, - bool interleave, float rope_scale, float rope_theta, - cudaStream_t stream = nullptr) { +cudaError_t BatchQKApplyRotaryPosIds(DType* q, DType* k, DType* q_rope, DType* k_rope, + IdType* __restrict__ pos_ids, uint32_t nnz, + uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, + size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n, + size_t q_rope_stride_h, size_t k_rope_stride_n, + size_t k_rope_stride_h, bool interleave, float rope_scale, + float rope_theta, cudaStream_t stream = nullptr) { float rope_rcp_scale = 1.0f / 
rope_scale; float rope_rcp_theta = 1.0f / rope_theta; float smooth_a = 0.f; @@ -299,21 +328,26 @@ cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ constexpr uint32_t bdx = HEAD_DIM / vec_size; uint32_t num_threads = std::max(128U, bdx); uint32_t bdy = num_threads / bdx; - dim3 nblks(batch_size * (num_qo_heads + num_kv_heads)); + dim3 nblks((nnz + bdy - 1) / bdy); dim3 nthrs(bdx, bdy); auto kernel = - BatchQKApplyRotaryInPlaceKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>; + BatchQKApplyRotaryPosIdsKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>; void* args[] = {(void*)&q, (void*)&k, - (void*)&indptr, - (void*)&offsets, - (void*)&batch_size, + (void*)&q_rope, + (void*)&k_rope, + (void*)&pos_ids, + (void*)&nnz, (void*)&num_qo_heads, (void*)&num_kv_heads, (void*)&q_stride_n, (void*)&q_stride_h, (void*)&k_stride_n, (void*)&k_stride_h, + (void*)&q_rope_stride_n, + (void*)&q_rope_stride_h, + (void*)&k_rope_stride_n, + (void*)&k_rope_stride_h, (void*)&smooth_a, (void*)&smooth_b, (void*)&rope_rcp_scale, @@ -326,16 +360,18 @@ cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ } template <typename DType, typename IdType> -cudaError_t BatchQKApplyLlama31RotaryInPlace( - DType* __restrict__ q, DType* __restrict__ k, IdType* __restrict__ indptr, - IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, - uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, - bool interleave, float rope_scale, float rope_theta, float low_freq_factor, - float high_freq_factor, float old_context_length, cudaStream_t stream = nullptr) { +cudaError_t BatchQKApplyRotary(DType* q, DType* k, DType* q_rope, DType* k_rope, + IdType* __restrict__ indptr, IdType* __restrict__ offsets, + uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, + size_t k_stride_n, size_t 
k_stride_h, size_t q_rope_stride_n, + size_t q_rope_stride_h, size_t k_rope_stride_n, + size_t k_rope_stride_h, bool interleave, float rope_scale, + float rope_theta, cudaStream_t stream = nullptr) { float rope_rcp_scale = 1.0f / rope_scale; float rope_rcp_theta = 1.0f / rope_theta; - float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor); - float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f); + float smooth_a = 0.f; + float smooth_b = 0.f; DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { @@ -345,10 +381,11 @@ cudaError_t BatchQKApplyLlama31RotaryInPlace( uint32_t bdy = num_threads / bdx; dim3 nblks(batch_size * (num_qo_heads + num_kv_heads)); dim3 nthrs(bdx, bdy); - auto kernel = - BatchQKApplyRotaryInPlaceKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>; + auto kernel = BatchQKApplyRotaryKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>; void* args[] = {(void*)&q, (void*)&k, + (void*)&q_rope, + (void*)&k_rope, (void*)&indptr, (void*)&offsets, (void*)&batch_size, @@ -358,6 +395,10 @@ cudaError_t BatchQKApplyLlama31RotaryInPlace( (void*)&q_stride_h, (void*)&k_stride_n, (void*)&k_stride_h, + (void*)&q_rope_stride_n, + (void*)&q_rope_stride_h, + (void*)&k_rope_stride_n, + (void*)&k_rope_stride_h, (void*)&smooth_a, (void*)&smooth_b, (void*)&rope_rcp_scale, @@ -370,17 +411,17 @@ cudaError_t BatchQKApplyLlama31RotaryInPlace( } template <typename DType, typename IdType> -cudaError_t BatchQKApplyRotary(DType* __restrict__ q, DType* __restrict__ k, - DType* __restrict__ q_rope, DType* __restrict__ k_rope, - IdType* __restrict__ indptr, IdType* __restrict__ offsets, - uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, - uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, - size_t k_stride_n, size_t k_stride_h, bool interleave, - float rope_scale, float rope_theta, cudaStream_t stream = nullptr) { +cudaError_t BatchQKApplyLlama31Rotary( + 
DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ indptr, + IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, + size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, + bool interleave, float rope_scale, float rope_theta, float low_freq_factor, + float high_freq_factor, float old_context_length, cudaStream_t stream = nullptr) { float rope_rcp_scale = 1.0f / rope_scale; float rope_rcp_theta = 1.0f / rope_theta; - float smooth_a = 0.f; - float smooth_b = 0.f; + float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor); + float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f); DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { @@ -404,6 +445,10 @@ cudaError_t BatchQKApplyRotary(DType* __restrict__ q, DType* __restrict__ k, (void*)&q_stride_h, (void*)&k_stride_n, (void*)&k_stride_h, + (void*)&q_rope_stride_n, + (void*)&q_rope_stride_h, + (void*)&k_rope_stride_n, + (void*)&k_rope_stride_h, (void*)&smooth_a, (void*)&smooth_b, (void*)&rope_rcp_scale, @@ -416,15 +461,13 @@ cudaError_t BatchQKApplyRotary(DType* __restrict__ q, DType* __restrict__ k, } template <typename DType, typename IdType> -cudaError_t BatchQKApplyLlama31Rotary(DType* __restrict__ q, DType* __restrict__ k, - DType* __restrict__ q_rope, DType* __restrict__ k_rope, - IdType* __restrict__ indptr, IdType* __restrict__ offsets, - uint32_t batch_size, uint32_t num_qo_heads, - uint32_t num_kv_heads, uint32_t head_dim, size_t q_stride_n, - size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, - bool interleave, float rope_scale, float rope_theta, - float low_freq_factor, float high_freq_factor, - float old_context_length, cudaStream_t stream = nullptr) { +cudaError_t BatchQKApplyLlama31RotaryPosIds( + DType* q, DType* 
k, DType* q_rope, DType* k_rope, IdType* pos_ids, uint32_t nnz, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, size_t q_stride_n, + size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n, + size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, bool interleave, + float rope_scale, float rope_theta, float low_freq_factor, float high_freq_factor, + float old_context_length, cudaStream_t stream = nullptr) { float rope_rcp_scale = 1.0f / rope_scale; float rope_rcp_theta = 1.0f / rope_theta; float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor); @@ -436,22 +479,26 @@ cudaError_t BatchQKApplyLlama31Rotary(DType* __restrict__ q, DType* __restrict__ constexpr uint32_t bdx = HEAD_DIM / vec_size; uint32_t num_threads = std::max(128U, bdx); uint32_t bdy = num_threads / bdx; - dim3 nblks(batch_size * (num_qo_heads + num_kv_heads)); + dim3 nblks((nnz + bdy - 1) / bdy); dim3 nthrs(bdx, bdy); - auto kernel = BatchQKApplyRotaryKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>; + auto kernel = + BatchQKApplyRotaryPosIdsKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>; void* args[] = {(void*)&q, (void*)&k, (void*)&q_rope, (void*)&k_rope, - (void*)&indptr, - (void*)&offsets, - (void*)&batch_size, + (void*)&pos_ids, + (void*)&nnz, (void*)&num_qo_heads, (void*)&num_kv_heads, (void*)&q_stride_n, (void*)&q_stride_h, (void*)&k_stride_n, (void*)&k_stride_h, + (void*)&q_rope_stride_n, + (void*)&q_rope_stride_h, + (void*)&k_rope_stride_n, + (void*)&k_rope_stride_h, (void*)&smooth_a, (void*)&smooth_b, (void*)&rope_rcp_scale, diff --git a/python/csrc/flashinfer_rope_ops.cu b/python/csrc/flashinfer_rope_ops.cu index 4075930b..ef046ead 100644 --- a/python/csrc/flashinfer_rope_ops.cu +++ b/python/csrc/flashinfer_rope_ops.cu @@ -15,28 +15,36 @@ */ #include <torch/extension.h> -void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, - torch::Tensor 
offsets, bool interleave, float rope_scale, float rope_theta); +#include <vector> -void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, - torch::Tensor offsets, bool interleave, float rope_scale, - float rope_theta, float low_freq_factor, float high_freq_factor, - float old_context_length); - -std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, +std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, + torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta); std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k, + torch::Tensor q_rope, torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta, float low_freq_factor, float high_freq_factor, float old_context_length); +std::vector<torch::Tensor> apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, + torch::Tensor q_rope, torch::Tensor k_rope, + torch::Tensor pos_ids, bool interleave, + float rope_scale, float rope_theta); + +std::vector<torch::Tensor> apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, + torch::Tensor q_rope, torch::Tensor k_rope, + torch::Tensor pos_ids, bool interleave, + float rope_scale, float rope_theta, + float low_freq_factor, float high_freq_factor, + float old_context_length); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place"); - m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace, - "Apply Llama 3.1 style RoPE in-place"); m.def("apply_rope", &apply_rope, "Apply RoPE"); m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE"); + m.def("apply_rope_pos_ids", &apply_rope_pos_ids, "Apply RoPE with positional ids"); + m.def("apply_llama31_rope_pos_ids", &apply_llama31_rope_pos_ids, + "Apply Llama 3.1 style RoPE with positional 
ids"); } diff --git a/python/csrc/rope.cu b/python/csrc/rope.cu index bb8d5a19..d2ca9155 100644 --- a/python/csrc/rope.cu +++ b/python/csrc/rope.cu @@ -19,9 +19,10 @@ using namespace flashinfer; -void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, - torch::Tensor offsets, bool interleave, float rope_scale, - float rope_theta) { +std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, + torch::Tensor k_rope, torch::Tensor indptr, + torch::Tensor offsets, bool interleave, float rope_scale, + float rope_theta) { CHECK_CUDA(q); // not necessarily contiguous CHECK_CUDA(k); // not necessarily contiguous CHECK_INPUT(indptr); @@ -44,68 +45,80 @@ void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, size_t q_stride_h = q.stride(1); size_t k_stride_n = k.stride(0); size_t k_stride_h = k.stride(1); + size_t q_rope_stride_n = q_rope.stride(0); + size_t q_rope_stride_h = q_rope.stride(1); + size_t k_rope_stride_n = k_rope.stride(0); + size_t k_rope_stride_h = k_rope.stride(1); indptr = indptr.to(torch::kInt32); offsets = offsets.to(torch::kInt32); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { - cudaError_t status = BatchQKApplyRotaryInPlace( + cudaError_t status = BatchQKApplyRotary( static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()), + static_cast<c_type*>(q_rope.data_ptr()), static_cast<c_type*>(k_rope.data_ptr()), static_cast<int32_t*>(indptr.data_ptr()), static_cast<int32_t*>(offsets.data_ptr()), batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, - k_stride_h, interleave, rope_scale, rope_theta, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotaryInPlace failed with error code " + + k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h, interleave, + rope_scale, rope_theta, 
torch_current_stream); + TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotary failed with error code " + std::string(cudaGetErrorString(status))); return true; }); + + return {q_rope, k_rope}; } -void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, - torch::Tensor offsets, bool interleave, float rope_scale, - float rope_theta, float low_freq_factor, float high_freq_factor, - float old_context_length) { +std::vector<torch::Tensor> apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, + torch::Tensor q_rope, torch::Tensor k_rope, + torch::Tensor pos_ids, bool interleave, + float rope_scale, float rope_theta) { CHECK_CUDA(q); // not necessarily contiguous CHECK_CUDA(k); // not necessarily contiguous - CHECK_INPUT(indptr); - CHECK_INPUT(offsets); + CHECK_INPUT(pos_ids); auto device = q.device(); CHECK_EQ(k.device(), device); - CHECK_DIM(3, q); // q: (nnz, H_Q, D) - CHECK_DIM(3, k); // k: (nnz, H_K, D) - CHECK_DIM(1, indptr); // indptr: (B + 1) - CHECK_DIM(1, offsets); // offsets: (B) + CHECK_DIM(3, q); // q: (nnz, H_Q, D) + CHECK_DIM(3, k); // k: (nnz, H_K, D) CHECK_EQ(q.size(0), k.size(0)); CHECK_EQ(q.size(2), k.size(2)); unsigned int num_qo_heads = q.size(1); unsigned int num_kv_heads = k.size(1); unsigned int head_dim = q.size(2); - unsigned int batch_size = offsets.size(0); - CHECK_EQ(indptr.size(0), batch_size + 1); + unsigned int nnz = q.size(0); size_t q_stride_n = q.stride(0); size_t q_stride_h = q.stride(1); size_t k_stride_n = k.stride(0); size_t k_stride_h = k.stride(1); - indptr = indptr.to(torch::kInt32); - offsets = offsets.to(torch::kInt32); + size_t q_rope_stride_n = q_rope.stride(0); + size_t q_rope_stride_h = q_rope.stride(1); + size_t k_rope_stride_n = k_rope.stride(0); + size_t k_rope_stride_h = k_rope.stride(1); + pos_ids = pos_ids.to(torch::kInt32); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { - 
cudaError_t status = BatchQKApplyLlama31RotaryInPlace( + cudaError_t status = BatchQKApplyRotaryPosIds( static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()), - static_cast<int32_t*>(indptr.data_ptr()), static_cast<int32_t*>(offsets.data_ptr()), - batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, - k_stride_h, interleave, rope_scale, rope_theta, low_freq_factor, high_freq_factor, - old_context_length, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, "BatchQKApplyLlama31RotaryInPlace failed with error code " + + static_cast<c_type*>(q_rope.data_ptr()), static_cast<c_type*>(k_rope.data_ptr()), + static_cast<int32_t*>(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, head_dim, + q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, + k_rope_stride_n, k_rope_stride_h, interleave, rope_scale, rope_theta, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotaryPosIds failed with error code " + std::string(cudaGetErrorString(status))); return true; }); + + return {q_rope, k_rope}; } -std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, - torch::Tensor offsets, bool interleave, float rope_scale, - float rope_theta) { +std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k, + torch::Tensor q_rope, torch::Tensor k_rope, + torch::Tensor indptr, torch::Tensor offsets, + bool interleave, float rope_scale, float rope_theta, + float low_freq_factor, float high_freq_factor, + float old_context_length) { CHECK_CUDA(q); // not necessarily contiguous CHECK_CUDA(k); // not necessarily contiguous CHECK_INPUT(indptr); @@ -128,21 +141,24 @@ std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::T size_t q_stride_h = q.stride(1); size_t k_stride_n = k.stride(0); size_t k_stride_h = k.stride(1); + size_t q_rope_stride_n = q_rope.stride(0); + size_t q_rope_stride_h = q_rope.stride(1); + 
size_t k_rope_stride_n = k_rope.stride(0); + size_t k_rope_stride_h = k_rope.stride(1); indptr = indptr.to(torch::kInt32); offsets = offsets.to(torch::kInt32); - // NOTE(Zihao): empty_like do not copy strides so it's okay to use it here. - auto q_rope = torch::empty_like(q); - auto k_rope = torch::empty_like(k); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { - cudaError_t status = BatchQKApplyRotary( + cudaError_t status = BatchQKApplyLlama31Rotary( static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()), static_cast<c_type*>(q_rope.data_ptr()), static_cast<c_type*>(k_rope.data_ptr()), static_cast<int32_t*>(indptr.data_ptr()), static_cast<int32_t*>(offsets.data_ptr()), batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, - k_stride_h, interleave, rope_scale, rope_theta, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotary failed with error code " + + k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h, interleave, + rope_scale, rope_theta, low_freq_factor, high_freq_factor, old_context_length, + torch_current_stream); + TORCH_CHECK(status == cudaSuccess, "BatchQKApplyLlama31Rotary failed with error code " + std::string(cudaGetErrorString(status))); return true; }); @@ -150,50 +166,46 @@ std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::T return {q_rope, k_rope}; } -std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k, - torch::Tensor indptr, torch::Tensor offsets, - bool interleave, float rope_scale, float rope_theta, - float low_freq_factor, float high_freq_factor, - float old_context_length) { +std::vector<torch::Tensor> apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, + torch::Tensor q_rope, torch::Tensor k_rope, + torch::Tensor pos_ids, bool interleave, + float rope_scale, float rope_theta, 
+ float low_freq_factor, float high_freq_factor, + float old_context_length) { CHECK_CUDA(q); // not necessarily contiguous CHECK_CUDA(k); // not necessarily contiguous - CHECK_INPUT(indptr); - CHECK_INPUT(offsets); + CHECK_INPUT(pos_ids); auto device = q.device(); CHECK_EQ(k.device(), device); - CHECK_DIM(3, q); // q: (nnz, H_Q, D) - CHECK_DIM(3, k); // k: (nnz, H_K, D) - CHECK_DIM(1, indptr); // indptr: (B + 1) - CHECK_DIM(1, offsets); // offsets: (B) + CHECK_DIM(3, q); // q: (nnz, H_Q, D) + CHECK_DIM(3, k); // k: (nnz, H_K, D) CHECK_EQ(q.size(0), k.size(0)); CHECK_EQ(q.size(2), k.size(2)); unsigned int num_qo_heads = q.size(1); unsigned int num_kv_heads = k.size(1); unsigned int head_dim = q.size(2); - unsigned int batch_size = offsets.size(0); - CHECK_EQ(indptr.size(0), batch_size + 1); + unsigned int nnz = q.size(0); size_t q_stride_n = q.stride(0); size_t q_stride_h = q.stride(1); size_t k_stride_n = k.stride(0); size_t k_stride_h = k.stride(1); - indptr = indptr.to(torch::kInt32); - offsets = offsets.to(torch::kInt32); - - // NOTE(Zihao): empty_like do not copy strides so it's okay to use it here. 
- auto q_rope = torch::empty_like(q); - auto k_rope = torch::empty_like(k); + size_t q_rope_stride_n = q_rope.stride(0); + size_t q_rope_stride_h = q_rope.stride(1); + size_t k_rope_stride_n = k_rope.stride(0); + size_t k_rope_stride_h = k_rope.stride(1); + pos_ids = pos_ids.to(torch::kInt32); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { - cudaError_t status = BatchQKApplyLlama31Rotary( + cudaError_t status = BatchQKApplyLlama31RotaryPosIds( static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()), static_cast<c_type*>(q_rope.data_ptr()), static_cast<c_type*>(k_rope.data_ptr()), - static_cast<int32_t*>(indptr.data_ptr()), static_cast<int32_t*>(offsets.data_ptr()), - batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, - k_stride_h, interleave, rope_scale, rope_theta, low_freq_factor, high_freq_factor, - old_context_length, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, "BatchQKApplyLlama31Rotary failed with error code " + + static_cast<int32_t*>(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, head_dim, + q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, + k_rope_stride_n, k_rope_stride_h, interleave, rope_scale, rope_theta, low_freq_factor, + high_freq_factor, old_context_length, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, "BatchQKApplyLlama31RotaryPosIds failed with error code " + std::string(cudaGetErrorString(status))); return true; }); diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py index cb023a67..724fc3f0 100644 --- a/python/flashinfer/__init__.py +++ b/python/flashinfer/__init__.py @@ -60,6 +60,8 @@ apply_llama31_rope_inplace as apply_llama31_rope_inplace, apply_rope as apply_rope, apply_rope_inplace as apply_rope_inplace, + apply_rope_pos_ids as apply_rope_pos_ids, + apply_rope_pos_ids_inplace as 
apply_rope_pos_ids_inplace, ) from .sampling import ( chain_speculative_sampling as chain_speculative_sampling, diff --git a/python/flashinfer/rope.py b/python/flashinfer/rope.py index 408c1f4a..29c2fcb7 100644 --- a/python/flashinfer/rope.py +++ b/python/flashinfer/rope.py @@ -118,8 +118,8 @@ def apply_rope_inplace( -------- apply_rope """ - return get_rope_module().apply_rope_inplace( - q, k, indptr, offsets, interleave, rope_scale, rope_theta + get_rope_module().apply_rope( + q, k, q, k, indptr, offsets, interleave, rope_scale, rope_theta ) @@ -136,6 +136,70 @@ def _fake_apply_rope_inplace( pass +@register_custom_op("flashinfer::apply_rope_pos_ids_inplace", mutates_args=("q", "k")) +def apply_rope_pos_ids_inplace( + q: torch.Tensor, + k: torch.Tensor, + pos_ids: torch.Tensor, + interleave: bool = False, + rope_scale: float = 1, + rope_theta: float = 1e4, +) -> None: + r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor) inplace. + + We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th + segment the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the + i-th segment is ``k[indptr[i]:indptr[i+1]]``, the first element of :attr:`indptr` is always + 0 and the last element of :attr:`indptr` is the total number of queries/keys in the batch. + Please see :ref:`Ragged Tensor tutorial <ragged-layout>` for more details about the + ragged tensor. + + Parameters + ---------- + q : torch.Tensor + Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the last + element of ``indptr``. + k : torch.Tensor + Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last + element of ``indptr``. + pos_ids : torch.Tensor + Position indices, shape: ``(nnz)``. + interleave : bool + Whether to use interleaved layout in the last dimension, default: ``False``. 
+ + * If ``True``, the last dimension of the query/key tensor is interleaved, i.e., + we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``. + + * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e., + we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half + dimensions ``([..., head_dim//2:])``. + + rope_scale : float + The scaling factor used in the rope embedding, default: ``1``. + rope_theta : float + The theta value used in the rope embedding, default: ``1e4``. + + See Also + -------- + apply_rope_pos_ids + """ + get_rope_module().apply_rope_pos_ids( + q, k, q, k, pos_ids, interleave, rope_scale, rope_theta + ) + + +@register_fake_op("flashinfer::apply_rope_pos_ids_inplace") +def _fake_apply_rope_pos_ids_inplace( + q: torch.Tensor, + k: torch.Tensor, + pos_ids: torch.Tensor, + interleave: bool = False, + rope_scale: float = 1, + rope_theta: float = 1e4, +) -> None: + pass + + @register_custom_op("flashinfer::apply_llama31_rope_inplace", mutates_args=("q", "k")) def apply_llama31_rope_inplace( q: torch.Tensor, @@ -222,7 +286,9 @@ def apply_llama31_rope_inplace( -------- apply_llama31_rope """ - return get_rope_module().apply_llama31_rope_inplace( + get_rope_module().apply_llama31_rope( + q, + k, q, k, indptr, @@ -339,8 +405,10 @@ def apply_rope( -------- apply_rope_inplace """ + q_rope = torch.empty_like(q) + k_rope = torch.empty_like(k) return get_rope_module().apply_rope( - q, k, indptr, offsets, interleave, rope_scale, rope_theta + q, k, q_rope, k_rope, indptr, offsets, interleave, rope_scale, rope_theta ) @@ -357,6 +425,79 @@ def _fake_apply_rope( return torch.empty_like(q), torch.empty_like(k) +@register_custom_op("flashinfer::apply_rope_pos_ids", mutates_args=()) +def apply_rope_pos_ids( + q: torch.Tensor, + k: torch.Tensor, + pos_ids: torch.Tensor, + interleave: bool = False, + rope_scale: float = 1, + rope_theta: float = 1e4, +) -> Tuple[torch.Tensor, torch.Tensor]: + 
r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor). + + We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th + segment the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the + i-th segment is ``k[indptr[i]:indptr[i+1]]``, the first element of :attr:`indptr` is always + 0 and the last element of :attr:`indptr` is the total number of queries/keys in the batch. + Please see :ref:`Ragged Tensor tutorial <ragged-layout>` for more details about the + ragged tensor. + + Parameters + ---------- + q : torch.Tensor + Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the last + element of ``indptr``. + k : torch.Tensor + Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last + element of ``indptr``. + pos_ids : torch.Tensor + Position indices, shape: ``(nnz)``. + interleave : bool + Whether to use interleaved layout in the last dimension, default: ``False``. + + * If ``True``, the last dimension of the query/key tensor is interleaved, i.e., + we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``. + + * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e., + we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half + dimensions ``([..., head_dim//2:])``. + + rope_scale : float + The scaling factor used in the rope embedding, default: ``1``. + rope_theta : float + The theta value used in the rope embedding, default: ``1e4``. + + Returns + ------- + q_rope : torch.Tensor + The rotated query tensor, shape: ``(nnz, num_q_heads, head_dim)``. + k_rope : torch.Tensor + The rotated key tensor, shape: ``(nnz, num_k_heads, head_dim)``. 
+ + See Also + -------- + apply_rope_inplace + """ + q_rope = torch.empty_like(q) + k_rope = torch.empty_like(k) + return get_rope_module().apply_rope_pos_ids( + q, k, q_rope, k_rope, pos_ids, interleave, rope_scale, rope_theta + ) + + +@register_fake_op("flashinfer::apply_rope_pos_ids") +def _fake_apply_rope_pos_ids( + q: torch.Tensor, + k: torch.Tensor, + pos_ids: torch.Tensor, + interleave: bool = False, + rope_scale: float = 1, + rope_theta: float = 1e4, +) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(q), torch.empty_like(k) + + @register_custom_op("flashinfer::apply_llama31_rope", mutates_args=()) def apply_llama31_rope( q: torch.Tensor, @@ -454,9 +595,13 @@ def apply_llama31_rope( -------- apply_llama31_rope_inplace """ + q_rope = torch.empty_like(q) + k_rope = torch.empty_like(k) return get_rope_module().apply_llama31_rope( q, k, + q_rope, + k_rope, indptr, offsets, interleave, diff --git a/tests/test_rope.py b/tests/test_rope.py index 1750ed34..f7ee84c9 100644 --- a/tests/test_rope.py +++ b/tests/test_rope.py @@ -69,12 +69,8 @@ def test_llama_rope_inplace( ) # compare - torch.testing.assert_close( - q_rope_ref, q, rtol=1e-3, atol=1e-3 - ) - torch.testing.assert_close( - k_rope_ref, k, rtol=1e-3, atol=1e-3 - ) + torch.testing.assert_close(q_rope_ref, q, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(k_rope_ref, k, rtol=1e-3, atol=1e-3) @pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) @@ -125,12 +121,111 @@ def test_llama_rope( ) # compare - torch.testing.assert_close( - q_rope_ref, q_rope, rtol=1e-3, atol=1e-3 + torch.testing.assert_close(q_rope_ref, q_rope, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(k_rope_ref, k_rope, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("qkv_len", [1, 4, 19, 204]) +@pytest.mark.parametrize("num_qo_heads", [8, 16]) +@pytest.mark.parametrize("num_kv_heads", [8]) +@pytest.mark.parametrize("offset", [0, 15, 99]) 
+@pytest.mark.parametrize("head_dim", [64, 128, 256]) +def test_llama_rope_pos_ids( + batch_size, + qkv_len, + num_qo_heads, + num_kv_heads, + offset, + head_dim, +): + nnz = batch_size * qkv_len + qkv_packed = torch.randn( + nnz, + (num_qo_heads + 2 * num_kv_heads) * head_dim, + dtype=torch.float16, + device="cuda:0", ) - torch.testing.assert_close( - k_rope_ref, k_rope, rtol=1e-3, atol=1e-3 + q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim) + k = qkv_packed[ + :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim + ].reshape(nnz, num_kv_heads, head_dim) + indptr = torch.tensor( + [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0" ) + offsets = torch.full((batch_size,), offset, dtype=torch.int32, device="cuda:0") + + pos_ids = torch.cat( + [ + torch.arange(offset, qkv_len + offset, dtype=torch.int32) + for _ in range(batch_size) + ] + ).to("cuda:0") + + q_rope, k_rope = flashinfer.apply_rope( + q, k, indptr, offsets, interleave=True, rope_theta=1e4 + ) + + q_rope_pos_ids, k_rope_pos_ids = flashinfer.apply_rope_pos_ids( + q, k, pos_ids, interleave=True, rope_theta=1e4 + ) + + # compare + torch.testing.assert_close(q_rope_pos_ids, q_rope, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(k_rope_pos_ids, k_rope, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("qkv_len", [1, 4, 19, 204]) +@pytest.mark.parametrize("num_qo_heads", [8, 16]) +@pytest.mark.parametrize("num_kv_heads", [8]) +@pytest.mark.parametrize("offset", [0, 15, 99]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +def test_llama_rope_pos_ids_inplace( + batch_size, + qkv_len, + num_qo_heads, + num_kv_heads, + offset, + head_dim, +): + nnz = batch_size * qkv_len + qkv_packed = torch.randn( + nnz, + (num_qo_heads + 2 * num_kv_heads) * head_dim, + dtype=torch.float16, + device="cuda:0", + ) + q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, 
num_qo_heads, head_dim) + k = qkv_packed[ + :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim + ].reshape(nnz, num_kv_heads, head_dim) + indptr = torch.tensor( + [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0" + ) + offsets = torch.full((batch_size,), offset, dtype=torch.int32, device="cuda:0") + + pos_ids = torch.cat( + [ + torch.arange(offset, qkv_len + offset, dtype=torch.int32) + for _ in range(batch_size) + ] + ).to("cuda:0") + + q_clone = q.clone() + k_clone = k.clone() + + flashinfer.apply_rope_inplace( + q, k, indptr, offsets, interleave=True, rope_theta=1e4 + ) + + flashinfer.apply_rope_pos_ids_inplace( + q_clone, k_clone, pos_ids, interleave=True, rope_theta=1e4 + ) + + # compare + torch.testing.assert_close(q_clone, q, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(k_clone, k, rtol=1e-3, atol=1e-3) @pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) @@ -181,12 +276,8 @@ def test_llama31_rope_inplace( ) # compare - torch.testing.assert_close( - q_rope_ref, q, rtol=1e-3, atol=1e-3 - ) - torch.testing.assert_close( - k_rope_ref, k, rtol=1e-3, atol=1e-3 - ) + torch.testing.assert_close(q_rope_ref, q, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(k_rope_ref, k, rtol=1e-3, atol=1e-3) @pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) @@ -237,12 +328,8 @@ def test_llama31_rope( ) # compare - torch.testing.assert_close( - q_rope_ref, q_rope, rtol=1e-3, atol=1e-3 - ) - torch.testing.assert_close( - k_rope_ref, k_rope, rtol=1e-3, atol=1e-3 - ) + torch.testing.assert_close(q_rope_ref, q_rope, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(k_rope_ref, k_rope, rtol=1e-3, atol=1e-3) if __name__ == "__main__": @@ -250,3 +337,5 @@ def test_llama31_rope( test_llama31_rope_inplace(1, 1, 8, 8, 0, 128) test_llama_rope(2, 1, 8, 8, 1, 128) test_llama31_rope(1, 1, 8, 8, 0, 128) + test_llama_rope_pos_ids(2, 1, 8, 8, 1, 128) + test_llama_rope_pos_ids_inplace(2, 1, 8, 8, 1, 128)