misc: attention kernel refactoring #339

Merged 13 commits on Oct 23, 2024
923 changes: 525 additions & 398 deletions src/kernels/attention/flash_infer/attention_kernel.h

Large diffs are not rendered by default.

39 changes: 16 additions & 23 deletions src/kernels/attention/flash_infer/attention_wrapper.cu
@@ -74,27 +74,18 @@ cudaError_t mha_varlen_wrapper_dispatch(BatchPrefillHandler* handler,
float sm_scale,
float* alibi_slopes,
cudaStream_t stream) {
-DTypeOut* tmp_v = nullptr;
-float* tmp_s = nullptr;
-IdType *request_indices = nullptr, *qo_tile_indices = nullptr,
-*kv_tile_indices = nullptr, *o_indptr = nullptr,
-*merge_indptr = nullptr, *kv_chunk_size_ptr = nullptr;
-bool* block_valid_mask = nullptr;
-WarpLayout warp_layout;
-uint32_t padded_batch_size = 0U;
-uint32_t total_num_rows = 0U;
-tmp_v = handler->GetTempV<DTypeOut>();
-tmp_s = handler->GetTempS();
-request_indices = handler->GetRequestIndices<IdType>();
-qo_tile_indices = handler->GetQOTileIndices<IdType>();
-kv_tile_indices = handler->GetKVTileIndices<IdType>();
-block_valid_mask = handler->GetBlockValidMask();
-o_indptr = handler->GetOIndptr<IdType>();
-merge_indptr = handler->GetMergeIndptr<IdType>();
-kv_chunk_size_ptr = handler->GetKVChunkSizePtr<IdType>();
-warp_layout = handler->GetWarpLayout();
-padded_batch_size = handler->GetPaddedBatchSize();
-total_num_rows = handler->GetTotalNumRows();
+DTypeOut* tmp_v = handler->GetTempV<DTypeOut>();
+float* tmp_s = handler->GetTempS();
+IdType* request_indices = handler->GetRequestIndices<IdType>();
+IdType* qo_tile_indices = handler->GetQOTileIndices<IdType>();
+IdType* kv_tile_indices = handler->GetKVTileIndices<IdType>();
+bool* block_valid_mask = handler->GetBlockValidMask();
+IdType* o_indptr = handler->GetOIndptr<IdType>();
+IdType* merge_indptr = handler->GetMergeIndptr<IdType>();
+IdType* kv_chunk_size_ptr = handler->GetKVChunkSizePtr<IdType>();
+WarpLayout warp_layout = handler->GetWarpLayout();
+uint32_t padded_batch_size = handler->GetPaddedBatchSize();
+uint32_t total_num_rows = handler->GetTotalNumRows();

DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, {
return mha_varlen_dispatch<WARP_LAYOUT,
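
For context, DISPATCH_WARP_LAYOUT is FlashInfer's enum-to-template dispatch macro: it switches on the runtime warp_layout value and exposes it as a compile-time constant, so mha_varlen_dispatch can be instantiated once per layout. A minimal sketch of the pattern follows; the WarpLayout enumerator names are illustrative stand-ins, not the library's actual values.

// Sketch of an enum-to-template dispatch macro in the style of
// DISPATCH_WARP_LAYOUT. Enumerator names are illustrative stand-ins.
enum class WarpLayout { k4x1, k4x2, k1x4 };

#define DISPATCH_WARP_LAYOUT(runtime_layout, WARP_LAYOUT, ...)  \
  switch (runtime_layout) {                                     \
    case WarpLayout::k4x1: {                                    \
      constexpr WarpLayout WARP_LAYOUT = WarpLayout::k4x1;      \
      __VA_ARGS__                                               \
    } break;                                                    \
    case WarpLayout::k4x2: {                                    \
      constexpr WarpLayout WARP_LAYOUT = WarpLayout::k4x2;      \
      __VA_ARGS__                                               \
    } break;                                                    \
    case WarpLayout::k1x4: {                                    \
      constexpr WarpLayout WARP_LAYOUT = WarpLayout::k1x4;      \
      __VA_ARGS__                                               \
    } break;                                                    \
  }

// Each switch arm instantiates the callee with a distinct compile-time
// layout, mirroring the mha_varlen_dispatch<WARP_LAYOUT, ...> call above.
template <WarpLayout kLayout>
int run_layout() { return static_cast<int>(kLayout); }

int dispatch(WarpLayout layout) {
  DISPATCH_WARP_LAYOUT(layout, WARP_LAYOUT, {
    return run_layout<WARP_LAYOUT>();
  });
  return -1;  // unreachable for valid enumerators
}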
@@ -145,7 +136,8 @@ void BatchPrefillWrapper::Plan(torch::Tensor float_workspace_buffer,
unsigned int num_kv_heads,
unsigned int head_dim,
unsigned int page_size,
-torch::Tensor empty_q_data) {
+torch::Tensor empty_q_data,
+int32_t num_sm) {
CHECK_INPUT(float_workspace_buffer);
CHECK_INPUT(int_workspace_buffer);
// NOTE(Zihao): not necessary to be a CUDA tensor
@@ -182,7 +174,8 @@ void BatchPrefillWrapper::Plan(torch::Tensor float_workspace_buffer,
num_qo_heads,
num_kv_heads,
head_dim,
-page_size);
+page_size,
+num_sm);
TORCH_CHECK(status == cudaSuccess,
"BatchPrefillWithPagedKVCache failed with error ",
cudaGetErrorString(status));
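The new num_sm parameter threads the device's streaming-multiprocessor count into the handler's plan step. A hedged sketch of how a caller might obtain it with the standard CUDA runtime API; the Plan() call is shown only as a comment, since its middle arguments are collapsed in the diff above.

// Sketch: querying the SM count on the host and handing it to the planner.
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  int num_sm = 0;
  cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount,
                         /*device=*/0);
  std::printf("multiprocessor count: %d\n", num_sm);
  // wrapper.Plan(/* ...buffers and shape arguments as in the signature
  //                 above... */, empty_q_data, num_sm);
  return 0;
}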
9 changes: 4 additions & 5 deletions src/kernels/attention/flash_infer/attention_wrapper.h
@@ -3,16 +3,14 @@

#include <torch/torch.h>

-#include <flashinfer/attention/warp_layout.cuh>
-
#include "handler.h"

namespace flashinfer {

class BatchPrefillWrapper {
public:
BatchPrefillWrapper(bool enable_cuda_graph)
-: handler_(std::make_shared<flashinfer::BatchPrefillHandler>(
+: handler_(std::make_unique<flashinfer::BatchPrefillHandler>(
enable_cuda_graph)) {}

void Plan(torch::Tensor float_workspace_buffer,
@@ -24,7 +22,8 @@ class BatchPrefillWrapper {
unsigned int num_kv_heads,
unsigned int head_dim,
unsigned page_size,
-torch::Tensor empty_q_data);
+torch::Tensor empty_q_data,
+int32_t num_sm);

bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }

@@ -43,7 +42,7 @@
std::optional<torch::Tensor> alibi_slopes);

private:
-std::shared_ptr<flashinfer::BatchPrefillHandler> handler_;
+std::unique_ptr<flashinfer::BatchPrefillHandler> handler_;
};

} // namespace flashinfer
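
The handler member changes from std::shared_ptr to std::unique_ptr (with std::make_unique in the constructor), encoding that the wrapper is the handler's sole owner and dropping the atomic reference-counting overhead that nothing here needs. A minimal sketch of the ownership pattern, with hypothetical stand-in type names:

// Sketch of the exclusive-ownership pattern adopted above. "Handler" and
// "Wrapper" are stand-in names for the types in this PR.
#include <memory>

struct Handler {
  explicit Handler(bool /*enable_cuda_graph*/) {}
};

class Wrapper {
 public:
  explicit Wrapper(bool enable_cuda_graph)
      : handler_(std::make_unique<Handler>(enable_cuda_graph)) {}

 private:
  std::unique_ptr<Handler> handler_;  // exclusive ownership; moves, never copies
};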