Skip to content

Commit

Permalink
bugfix: fix broken tvm integration caused by #568 (#582)
Browse files Browse the repository at this point in the history
Hi, the TVM wrapper build currently fails because of #568.
I noticed that the new `BatchQKApplyRotary` interface removed the
`__restrict__` modifier from `DType* q, DType* k, DType* q_rope, DType*
k_rope`, so the fix is trivial: add one adapter function that restores
the in-place entry point.

Co-authored-by: tsu-bin <[email protected]>
  • Loading branch information
tsu-bin and tsu-bin authored Nov 5, 2024
1 parent 979bb6c commit 7557dc8
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions include/flashinfer/pos_enc.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ __global__ void BatchQKApplyRotaryPosIdsKernel(
size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n, size_t q_rope_stride_h,
size_t k_rope_stride_n, size_t k_rope_stride_h, float smooth_a, float smooth_b,
float rope_rcp_scale, float rope_rcp_theta) {
// NOTE: q and q_rope may be the same pointer, and so may k and k_rope
uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
const uint32_t bdy = blockDim.y;
vec_t<float, vec_size> freq;
Expand Down Expand Up @@ -410,6 +411,21 @@ cudaError_t BatchQKApplyRotary(DType* q, DType* k, DType* q_rope, DType* k_rope,
return cudaSuccess;
}

// In-place adapter over BatchQKApplyRotary.
//
// Forwards every argument to BatchQKApplyRotary, supplying q and k a second time as the
// q_rope / k_rope arguments and reusing the q / k strides as the q_rope / k_rope strides,
// so the callee's destination buffers are the same buffers as its sources.
//
// Returns whatever BatchQKApplyRotary returns.
template <typename DType, typename IdType>
cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ k,
                                      IdType* __restrict__ indptr, IdType* __restrict__ offsets,
                                      uint32_t batch_size, uint32_t num_qo_heads,
                                      uint32_t num_kv_heads, uint32_t head_dim,
                                      size_t q_stride_n, size_t q_stride_h,
                                      size_t k_stride_n, size_t k_stride_h,
                                      bool interleave, float rope_scale, float rope_theta,
                                      cudaStream_t stream = nullptr) {
  // Destination pointers deliberately alias the source pointers.
  DType* const q_rope = q;
  DType* const k_rope = k;

  // Same buffers, therefore the same strides on the destination side.
  const size_t q_rope_stride_n = q_stride_n;
  const size_t q_rope_stride_h = q_stride_h;
  const size_t k_rope_stride_n = k_stride_n;
  const size_t k_rope_stride_h = k_stride_h;

  return BatchQKApplyRotary<DType, IdType>(
      q, k, q_rope, k_rope, indptr, offsets,
      batch_size, num_qo_heads, num_kv_heads, head_dim,
      q_stride_n, q_stride_h, k_stride_n, k_stride_h,
      q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h,
      interleave, rope_scale, rope_theta, stream);
}

template <typename DType, typename IdType>
cudaError_t BatchQKApplyLlama31Rotary(
DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ indptr,
Expand Down

0 comments on commit 7557dc8

Please sign in to comment.