From a9e9667d20598faec5831bcf6cf90f254cc45219 Mon Sep 17 00:00:00 2001 From: tsu-bin Date: Tue, 5 Nov 2024 12:57:28 +0800 Subject: [PATCH] fix broken tvm integration caused by #568 --- include/flashinfer/pos_enc.cuh | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index ed0b732a..442d03e3 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -410,6 +410,21 @@ cudaError_t BatchQKApplyRotary(DType* q, DType* k, DType* q_rope, DType* k_rope, return cudaSuccess; } +template +cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ k, + IdType* __restrict__ indptr, IdType* __restrict__ offsets, + uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, + size_t k_stride_n, size_t k_stride_h, + bool interleave, float rope_scale, float rope_theta, + cudaStream_t stream = nullptr) { + return BatchQKApplyRotary(q, k, q, k, indptr, offsets, batch_size, num_qo_heads, num_kv_heads, + head_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, + q_stride_n, q_stride_h, k_stride_n, k_stride_h, + interleave, rope_scale, rope_theta, stream); + +} + template cudaError_t BatchQKApplyLlama31Rotary( DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ indptr,