Skip to content

Commit

Permalink
fix broken tvm integration caused by #568
Browse files Browse the repository at this point in the history
  • Loading branch information
tsu-bin committed Nov 5, 2024
1 parent c3572de commit a9e9667
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions include/flashinfer/pos_enc.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,21 @@ cudaError_t BatchQKApplyRotary(DType* q, DType* k, DType* q_rope, DType* k_rope,
return cudaSuccess;
}

/*!
 * \brief In-place variant of BatchQKApplyRotary: the rotated values are written
 *   back into the same q / k buffers they were read from.
 *
 * This is a thin forwarding wrapper. It calls BatchQKApplyRotary with q and k
 * supplied as both the source pointers and the destination (q_rope / k_rope)
 * pointers, and passes the q / k strides a second time in the argument slots
 * that follow the first set of strides (presumably the output strides of the
 * out-of-place routine -- confirm against its full signature).
 *
 * All remaining arguments (indptr, offsets, batch_size, num_qo_heads,
 * num_kv_heads, head_dim, interleave, rope_scale, rope_theta, stream) are
 * forwarded unchanged; their meaning is defined by BatchQKApplyRotary.
 *
 * \return Whatever cudaError_t BatchQKApplyRotary returns.
 */
template <typename DType, typename IdType>
cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ k,
                                      IdType* __restrict__ indptr, IdType* __restrict__ offsets,
                                      uint32_t batch_size, uint32_t num_qo_heads,
                                      uint32_t num_kv_heads, uint32_t head_dim,
                                      size_t q_stride_n, size_t q_stride_h,
                                      size_t k_stride_n, size_t k_stride_h,
                                      bool interleave, float rope_scale, float rope_theta,
                                      cudaStream_t stream = nullptr) {
  // In-place operation: the destination buffers are the source buffers.
  DType* const q_rope = q;
  DType* const k_rope = k;

  const cudaError_t status = BatchQKApplyRotary<DType, IdType>(
      q, k, q_rope, k_rope, indptr, offsets,
      batch_size, num_qo_heads, num_kv_heads, head_dim,
      // First set of strides.
      q_stride_n, q_stride_h, k_stride_n, k_stride_h,
      // Second set of strides: identical, since source and destination alias.
      q_stride_n, q_stride_h, k_stride_n, k_stride_h,
      interleave, rope_scale, rope_theta, stream);
  return status;
}

template <typename DType, typename IdType>
cudaError_t BatchQKApplyLlama31Rotary(
DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ indptr,
Expand Down

0 comments on commit a9e9667

Please sign in to comment.