
feat: no-cache attention in PyTorch workflow #3085


Merged: 14 commits, merged Apr 4, 2025
16 changes: 9 additions & 7 deletions cpp/tensorrt_llm/common/attentionOp.cpp
@@ -2317,24 +2317,26 @@ int AttentionOp::initialize() noexcept
// TODO(yibinl): remove forceFp32Acc from MHARunnerFixedParams after adding host_runtime_perf_knobs to
// bertAttentionPlugin input tensors, so that we can change mLaunchParams.force_fp32_acc value in runtime.
fmhaParams.forceFp32Acc = false;

// Set the fmha attention mask type based on the op's mask type.
fmhaParams.setAttentionMaskType(static_cast<std::int8_t>(mMaskType));

if (isCrossAttention())
{
fmhaParams.attentionMaskType = ContextAttentionMaskType::PADDING;
// always use paged-kv-fmha if paged_kv cache is used.
fmhaParams.attentionInputLayout
= mPagedKVCache ? AttentionInputLayout::Q_PAGED_KV : AttentionInputLayout::Q_CONTIGUOUS_KV;
}
else if (!useKVCache())
{
fmhaParams.attentionInputLayout = AttentionInputLayout::PACKED_QKV;
}
else
{
fmhaParams.attentionMaskType = ContextAttentionMaskType::CAUSAL;
fmhaParams.attentionInputLayout = (mPagedKVCache && mPagedContextFMHA && !mIsMLAEnabled)
? AttentionInputLayout::Q_PAGED_KV
: AttentionInputLayout::PACKED_QKV;
}
if (useCustomMask())
{
fmhaParams.attentionMaskType = ContextAttentionMaskType::CUSTOM_MASK;
}
fmhaParams.isSPadded = !mRemovePadding;
fmhaParams.numQHeads = mNumAttnHeads;
fmhaParams.numKvHeads = mNumAttnKVHeads;
@@ -2448,7 +2450,7 @@ int AttentionOp::initialize() noexcept
}

mEnableXQA = (mEnableXQA || mIsSpecDecodingEnabled) && !mCrossAttention
&& (mType == nvinfer1::DataType::kHALF || mType == nvinfer1::DataType::kBF16);
&& (mType == nvinfer1::DataType::kHALF || mType == nvinfer1::DataType::kBF16) && mUseKVCache;

if (mEnableXQA)
{
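As a reading aid, here is a minimal standalone sketch of the input-layout selection this hunk implements; the enum stand-in, free function, and parameter names are illustrative, not part of the change.

#include <cstdint>

// Minimal stand-in for the real enum; enumerator names follow the hunk above, values are illustrative.
enum class AttentionInputLayout : std::int8_t
{
    PACKED_QKV,
    Q_CONTIGUOUS_KV,
    Q_PAGED_KV
};

// Sketch of the input-layout decision in AttentionOp::initialize(); parameter names are assumed.
AttentionInputLayout selectInputLayout(
    bool crossAttention, bool useKvCache, bool pagedKvCache, bool pagedContextFmha, bool mlaEnabled)
{
    if (crossAttention)
    {
        // Cross attention reads K/V from a separate buffer; use the paged-kv layout when the cache is paged.
        return pagedKvCache ? AttentionInputLayout::Q_PAGED_KV : AttentionInputLayout::Q_CONTIGUOUS_KV;
    }
    if (!useKvCache)
    {
        // New no-cache path: Q, K and V arrive packed in a single tensor and nothing is read from a cache.
        return AttentionInputLayout::PACKED_QKV;
    }
    // Self attention with a KV cache: paged-context FMHA reads K/V from the paged cache, except for MLA.
    return (pagedKvCache && pagedContextFmha && !mlaEnabled) ? AttentionInputLayout::Q_PAGED_KV
                                                             : AttentionInputLayout::PACKED_QKV;
}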
@@ -182,6 +182,44 @@ struct MHARunnerFixedParams

return output;
}

/**
* Set attention mask type from AttentionMaskType enum
* @param maskType The AttentionMaskType to use
* @return Reference to this object for method chaining
* @throws If the maskType cannot be mapped to ContextAttentionMaskType
*/
MHARunnerFixedParams& setAttentionMaskType(std::int8_t maskType)
{
switch (maskType)
{
case 0: // tensorrt_llm::kernels::AttentionMaskType::PADDING
attentionMaskType = ContextAttentionMaskType::PADDING;
break;
case 1: // tensorrt_llm::kernels::AttentionMaskType::CAUSAL
attentionMaskType = ContextAttentionMaskType::CAUSAL;
break;
case 2: // tensorrt_llm::kernels::AttentionMaskType::SLIDING_WINDOW_CAUSAL
attentionMaskType = ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL;
break;
// NOTE: For BIDIRECTIONAL, BIDIRECTIONALGLM, and BLOCKSPARSE, the CAUSAL mask is used in the context phase.
case 3: // tensorrt_llm::kernels::AttentionMaskType::BIDIRECTIONAL
attentionMaskType = ContextAttentionMaskType::CAUSAL;
break;
case 4: // tensorrt_llm::kernels::AttentionMaskType::BIDIRECTIONALGLM
attentionMaskType = ContextAttentionMaskType::CAUSAL;
break;
case 5: // tensorrt_llm::kernels::AttentionMaskType::BLOCKSPARSE
attentionMaskType = ContextAttentionMaskType::CAUSAL;
break;
case 6: // tensorrt_llm::kernels::AttentionMaskType::CUSTOM_MASK
attentionMaskType = ContextAttentionMaskType::CUSTOM_MASK;
break;
default:
TLLM_THROW("AttentionMaskType %d cannot be mapped to ContextAttentionMaskType", static_cast<int>(maskType));
}
return *this;
}
};

////////////////////////////////////////////////////////////////////////////////////////////////////
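A hedged usage sketch of the new setter: the example function is illustrative and assumes this header is included; the 3 -> CAUSAL mapping comes from the switch above.

#include <cassert>
#include <cstdint>
// Assumes the header defining MHARunnerFixedParams / ContextAttentionMaskType (shown above) is included.

void attentionMaskTypeMappingExample()
{
    MHARunnerFixedParams fmhaParams{};
    // AttentionMaskType::BIDIRECTIONAL has underlying value 3; the context phase falls back to CAUSAL.
    fmhaParams.setAttentionMaskType(static_cast<std::int8_t>(3));
    assert(fmhaParams.attentionMaskType == ContextAttentionMaskType::CAUSAL);
    // The setter returns *this, so it can be chained with further initialization if convenient.
}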
4 changes: 2 additions & 2 deletions cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp
@@ -92,7 +92,7 @@ bool FmhaDispatcher::isSupported()
TllmGenFmhaRunnerParams tllmRunnerParams;
memset(&tllmRunnerParams, 0, sizeof(tllmRunnerParams));
tllmRunnerParams.mQkvLayout = qkvLayout;
tllmRunnerParams.mMaskType = TrtllmGenAttentionMaskType::Causal;
tllmRunnerParams.setAttentionMaskType(static_cast<std::int8_t>(mFixedParams.attentionMaskType));
tllmRunnerParams.mKernelType = FmhaKernelType::Context;
tllmRunnerParams.mTileScheduler = TileScheduler::Persistent;
tllmRunnerParams.mMultiCtasKvMode = false;
@@ -143,7 +143,7 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams)

// Parameters to select kernels.
tllmRunnerParams.mQkvLayout = qkvLayout;
tllmRunnerParams.mMaskType = TrtllmGenAttentionMaskType::Causal;
tllmRunnerParams.setAttentionMaskType(static_cast<std::int8_t>(mFixedParams.attentionMaskType));
tllmRunnerParams.mKernelType = FmhaKernelType::Context;
// Always use persistent scheduler for better performance.
tllmRunnerParams.mTileScheduler = TileScheduler::Persistent;
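Taken together with the TllmGenFmhaRunnerParams setter added below, the mask type now flows end to end instead of being pinned to Causal. A small illustrative sketch of that chain; the wrapper function is not part of the change and assumes the relevant headers are in scope.

// Conversion chain after this change (previously the last step was hard-coded to Causal):
//   AttentionOp::mMaskType (kernels::AttentionMaskType, e.g. PADDING == 0)
//     -> MHARunnerFixedParams::attentionMaskType == ContextAttentionMaskType::PADDING
//     -> TllmGenFmhaRunnerParams::mMaskType      == TrtllmGenAttentionMaskType::Dense
void propagateMaskTypeExample(MHARunnerFixedParams const& fixedParams, TllmGenFmhaRunnerParams& tllmRunnerParams)
{
    tllmRunnerParams.setAttentionMaskType(static_cast<std::int8_t>(fixedParams.attentionMaskType));
}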
26 changes: 26 additions & 0 deletions cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h
@@ -234,6 +234,32 @@ struct TllmGenFmhaRunnerParams
float mScaleSfKv;
// The cuda stream.
cudaStream_t stream;

// set the attention mask type
TllmGenFmhaRunnerParams& setAttentionMaskType(std::int8_t maskType)
{
// maskType carries the underlying value of tensorrt_llm::kernels::ContextAttentionMaskType;
// convert it to TrtllmGenAttentionMaskType.
switch (maskType)
{
case 0: // tensorrt_llm::kernels::ContextAttentionMaskType::PADDING
mMaskType = TrtllmGenAttentionMaskType::Dense;
break;
case 1: // tensorrt_llm::kernels::ContextAttentionMaskType::CAUSAL
mMaskType = TrtllmGenAttentionMaskType::Causal;
break;
case 2: // tensorrt_llm::kernels::ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL
mMaskType = TrtllmGenAttentionMaskType::SlidingWindowCausal;
break;
case 3: // tensorrt_llm::kernels::ContextAttentionMaskType::CUSTOM_MASK
mMaskType = TrtllmGenAttentionMaskType::Custom;
break;
default:
TLLM_THROW("ContextAttentionMaskType %d cannot be mapped to TrtllmGenAttentionMaskType",
static_cast<int>(maskType));
}
return *this;
}
};

////////////////////////////////////////////////////////////////////////////////////////////////////
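Both setAttentionMaskType overloads switch on raw integer values, so they implicitly assume the source enumerators keep their current underlying values. A hedged sketch of compile-time guards that could accompany the mapping, assuming the enumerator names in the case comments above and visibility of ContextAttentionMaskType:

// Illustrative only; valid where ContextAttentionMaskType is visible.
static_assert(static_cast<int>(ContextAttentionMaskType::PADDING) == 0, "mapping assumes PADDING == 0");
static_assert(static_cast<int>(ContextAttentionMaskType::CAUSAL) == 1, "mapping assumes CAUSAL == 1");
static_assert(static_cast<int>(ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL) == 2,
    "mapping assumes SLIDING_WINDOW_CAUSAL == 2");
static_assert(static_cast<int>(ContextAttentionMaskType::CUSTOM_MASK) == 3, "mapping assumes CUSTOM_MASK == 3");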
109 changes: 63 additions & 46 deletions cpp/tensorrt_llm/thop/attentionOp.cpp
@@ -67,14 +67,15 @@ class RunnerBase
int32_t const token_offset, int32_t const num_tokens, int32_t const predicted_tokens_per_seq,
torch::Tensor workspace, torch::Tensor output, torch::Tensor qkv, torch::Tensor sequence_length,
torch::Tensor host_past_key_value_lengths, torch::Tensor context_lengths, torch::Tensor host_context_lengths,
torch::Tensor kv_cache_block_offsets, torch::Tensor host_kv_cache_block_offsets,
torch::Tensor host_kv_cache_pool_pointers, torch::Tensor host_kv_cache_pool_mapping,
torch::optional<torch::Tensor> cache_indirection, torch::optional<torch::Tensor> kv_scale_orig_quant,
torch::optional<torch::Tensor> kv_scale_quant_orig, torch::optional<torch::Tensor> out_scale,
torch::optional<torch::Tensor> rotary_inv_freq, torch::optional<torch::Tensor> rotary_cos_sin,
torch::optional<torch::Tensor> latent_cache, torch::optional<torch::Tensor> q_pe,
torch::optional<torch::Tensor> block_ids_per_seq, torch::optional<torch::Tensor> mrope_rotary_cos_sin,
torch::optional<torch::Tensor> mrope_position_deltas) const
torch::optional<torch::Tensor> kv_cache_block_offsets,
torch::optional<torch::Tensor> host_kv_cache_block_offsets,
torch::optional<torch::Tensor> host_kv_cache_pool_pointers,
torch::optional<torch::Tensor> host_kv_cache_pool_mapping, torch::optional<torch::Tensor> cache_indirection,
torch::optional<torch::Tensor> kv_scale_orig_quant, torch::optional<torch::Tensor> kv_scale_quant_orig,
torch::optional<torch::Tensor> out_scale, torch::optional<torch::Tensor> rotary_inv_freq,
torch::optional<torch::Tensor> rotary_cos_sin, torch::optional<torch::Tensor> latent_cache,
torch::optional<torch::Tensor> q_pe, torch::optional<torch::Tensor> block_ids_per_seq,
torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas) const
= 0;
};

@@ -114,13 +115,15 @@ class Runner : public RunnerBase
int32_t const token_offset, int32_t const num_tokens, int32_t const predicted_tokens_per_seq,
torch::Tensor workspace, torch::Tensor output, torch::Tensor qkv, torch::Tensor sequence_length,
torch::Tensor host_past_key_value_lengths, torch::Tensor context_lengths, torch::Tensor host_context_lengths,
torch::Tensor kv_cache_block_offsets, torch::Tensor host_kv_cache_block_offsets,
torch::Tensor host_kv_cache_pool_pointers, torch::Tensor host_kv_cache_pool_mapping,
torch::optional<torch::Tensor> cache_indirection, torch::optional<torch::Tensor> kv_scale_orig_quant,
torch::optional<torch::Tensor> kv_scale_quant_orig, torch::optional<torch::Tensor> out_scale,
torch::optional<torch::Tensor> rotary_inv_freq, torch::optional<torch::Tensor> rotary_cos_sin,
torch::optional<torch::Tensor> latent_cache, torch::optional<torch::Tensor> q_pe,
torch::optional<torch::Tensor> block_ids_per_seq, torch::optional<torch::Tensor> mrope_rotary_cos_sin,
torch::optional<torch::Tensor> kv_cache_block_offsets,
torch::optional<torch::Tensor> host_kv_cache_block_offsets,
torch::optional<torch::Tensor> host_kv_cache_pool_pointers,
torch::optional<torch::Tensor> host_kv_cache_pool_mapping, torch::optional<torch::Tensor> cache_indirection,
torch::optional<torch::Tensor> kv_scale_orig_quant, torch::optional<torch::Tensor> kv_scale_quant_orig,
torch::optional<torch::Tensor> out_scale, torch::optional<torch::Tensor> rotary_inv_freq,
torch::optional<torch::Tensor> rotary_cos_sin, torch::optional<torch::Tensor> latent_cache,
torch::optional<torch::Tensor> q_pe, torch::optional<torch::Tensor> block_ids_per_seq,
torch::optional<torch::Tensor> mrope_rotary_cos_sin,
torch::optional<torch::Tensor> mrope_position_deltas) const override
{
auto stream = at::cuda::getCurrentCUDAStream(qkv.get_device());
@@ -187,26 +190,32 @@ class Runner : public RunnerBase
int const cyclic_attention_window_size = attention_window_size;
bool const can_use_one_more_block = beam_width > 1;

int max_blocks_per_sequence = kv_cache_block_offsets.size(-1);
int32_t const pool_index = host_kv_cache_pool_mapping.index({op.mLayerIdx, 0}).item<int32_t>();
int32_t const layer_idx_in_cache_pool = host_kv_cache_pool_mapping.index({op.mLayerIdx, 1}).item<int32_t>();
KVBlockArray::DataType* block_offsets
= static_cast<KVBlockArray::DataType*>(kv_cache_block_offsets.index({pool_index, seq_offset}).data_ptr());
int max_blocks_per_sequence = op.useKVCache() ? kv_cache_block_offsets.value().size(-1) : 0;
int32_t const pool_index
= op.useKVCache() ? host_kv_cache_pool_mapping.value().index({op.mLayerIdx, 0}).item<int32_t>() : 0;
int32_t const layer_idx_in_cache_pool
= op.useKVCache() ? host_kv_cache_pool_mapping.value().index({op.mLayerIdx, 1}).item<int32_t>() : 0;
KVBlockArray::DataType* block_offsets = static_cast<KVBlockArray::DataType*>(
op.useKVCache() ? kv_cache_block_offsets.value().index({pool_index, seq_offset}).data_ptr() : nullptr);
KVBlockArray::DataType* host_block_offsets = static_cast<KVBlockArray::DataType*>(
host_kv_cache_block_offsets.index({pool_index, seq_offset}).data_ptr());
op.useKVCache() ? host_kv_cache_block_offsets.value().index({pool_index, seq_offset}).data_ptr() : nullptr);

auto const cache_elem_size = (op.mKVCacheQuantMode.hasKvCacheQuant() ? 1 : sizeof(T));
auto const block_size = op.mTokensPerBlock * op.mNumKVHeads * op.mHeadSize;
auto const bytes_per_block = block_size * cache_elem_size;
int32_t const kv_factor = op.isMLAEnabled() ? 1 : 2;
auto const intra_pool_offset = layer_idx_in_cache_pool * kv_factor * bytes_per_block;

void* host_primary_pool_pointer = reinterpret_cast<void*>(
reinterpret_cast<char*>(host_kv_cache_pool_pointers.index({pool_index, 0}).item<int64_t>())
+ intra_pool_offset);
void* host_secondary_pool_pointer = reinterpret_cast<void*>(
reinterpret_cast<char*>(host_kv_cache_pool_pointers.index({pool_index, 1}).item<int64_t>())
+ intra_pool_offset);
void* host_primary_pool_pointer = op.useKVCache()
? reinterpret_cast<void*>(
reinterpret_cast<char*>(host_kv_cache_pool_pointers.value().index({pool_index, 0}).item<int64_t>())
+ intra_pool_offset)
: nullptr;
void* host_secondary_pool_pointer = op.useKVCache()
? reinterpret_cast<void*>(
reinterpret_cast<char*>(host_kv_cache_pool_pointers.value().index({pool_index, 1}).item<int64_t>())
+ intra_pool_offset)
: nullptr;

float const* kv_scale_orig_quant_ptr = nullptr;
float const* kv_scale_quant_orig_ptr = nullptr;
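The repeated op.useKVCache() ? ... : fallback expressions above follow one guard idiom; a minimal sketch of that idiom, where the helper name and usage comment are illustrative and not part of the change:

// Sketch of the guard pattern used throughout Runner::run above: optional KV-cache inputs are only
// dereferenced when the cache is in use; otherwise a neutral fallback (0 / nullptr) keeps the
// downstream kernel parameters well defined.
template <typename T, typename Getter>
T ifKvCache(bool useKvCache, Getter&& get, T fallback)
{
    return useKvCache ? static_cast<T>(get()) : fallback;
}

// Usage mirroring the hunk above (lambda and variable names are illustrative):
//   int maxBlocksPerSequence = ifKvCache<int>(op.useKVCache(),
//       [&] { return static_cast<int>(kv_cache_block_offsets.value().size(-1)); }, 0);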
@@ -330,27 +339,32 @@ using torch_ext::trtllm::attention::AttentionInputType;
torch::Tensor attention(torch::Tensor q, torch::optional<torch::Tensor> k, torch::optional<torch::Tensor> v,
std::optional<torch::ScalarType> out_dtype, torch::optional<torch::Tensor> workspace_,
torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths, torch::Tensor context_lengths,
torch::Tensor host_context_lengths, torch::Tensor host_request_types, torch::Tensor kv_cache_block_offsets,
torch::Tensor host_kv_cache_block_offsets, torch::Tensor host_kv_cache_pool_pointers,
torch::Tensor host_kv_cache_pool_mapping, torch::optional<torch::Tensor> cache_indirection,
torch::Tensor host_context_lengths, torch::Tensor host_request_types,
torch::optional<torch::Tensor> kv_cache_block_offsets, torch::optional<torch::Tensor> host_kv_cache_block_offsets,
torch::optional<torch::Tensor> host_kv_cache_pool_pointers,
torch::optional<torch::Tensor> host_kv_cache_pool_mapping, torch::optional<torch::Tensor> cache_indirection,
torch::optional<torch::Tensor> kv_scale_orig_quant, torch::optional<torch::Tensor> kv_scale_quant_orig,
torch::optional<torch::Tensor> out_scale, torch::optional<torch::Tensor> rotary_inv_freq,
torch::optional<torch::Tensor> rotary_cos_sin, torch::optional<torch::Tensor> latent_cache,
torch::optional<torch::Tensor> q_pe, torch::optional<torch::Tensor> block_ids_per_seq, bool const is_fused_qkv,
bool const update_kv_cache, int64_t const predicted_tokens_per_seq, int64_t const layer_idx,
int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size, int64_t const tokens_per_block,
int64_t const max_num_requests, int64_t const max_context_length, int64_t const attention_window_size,
int64_t const sink_token_length, int64_t const beam_width, int64_t const mask_type, int64_t const quant_mode,
double const q_scaling, int64_t const position_embedding_type, int64_t const rotary_embedding_dim,
double const rotary_embedding_base, int64_t const rotary_embedding_scale_type, double const rotary_embedding_scale,
double const rotary_embedding_short_m_scale, double const rotary_embedding_long_m_scale,
int64_t const rotary_embedding_max_positions, int64_t const rotary_embedding_original_max_positions,
bool const use_paged_context_fmha, std::optional<int64_t> attention_input_type, bool is_mla_enable,
std::optional<int64_t> q_lora_rank, std::optional<int64_t> kv_lora_rank, std::optional<int64_t> qk_nope_head_dim,
int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size,
std::optional<int64_t> const tokens_per_block, int64_t const max_num_requests, int64_t const max_context_length,
int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
int64_t const mask_type, int64_t const quant_mode, double const q_scaling, int64_t const position_embedding_type,
int64_t const rotary_embedding_dim, double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
double const rotary_embedding_scale, double const rotary_embedding_short_m_scale,
double const rotary_embedding_long_m_scale, int64_t const rotary_embedding_max_positions,
int64_t const rotary_embedding_original_max_positions, bool const use_paged_context_fmha,
std::optional<int64_t> attention_input_type, bool is_mla_enable, std::optional<int64_t> q_lora_rank,
std::optional<int64_t> kv_lora_rank, std::optional<int64_t> qk_nope_head_dim,
std::optional<int64_t> qk_rope_head_dim, std::optional<int64_t> v_head_dim,
torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas)
{
TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx);
// Use these tensors to infer whether the attention uses a KV cache.
bool const use_kv_cache = kv_cache_block_offsets.has_value() && host_kv_cache_block_offsets.has_value()
&& host_kv_cache_pool_pointers.has_value() && host_kv_cache_pool_mapping.has_value();

TLLM_CHECK_WITH_INFO(is_fused_qkv, "Only fused QKV is supported now");
TLLM_CHECK_WITH_INFO(update_kv_cache, "KV cache update cannot be disabled now");
@@ -416,7 +430,9 @@ torch::Tensor attention(torch::Tensor q, torch::optional<torch::Tensor> k, torch
op->mHeadSize = head_size;
op->mMaskType = static_cast<tensorrt_llm::kernels::AttentionMaskType>(int32_t(mask_type));
op->mKVCacheQuantMode = tensorrt_llm::common::QuantMode(uint32_t(quant_mode));
op->mTokensPerBlock = tokens_per_block;
op->mUseKVCache = use_kv_cache;
op->mPagedKVCache = op->mPagedKVCache && use_kv_cache; // update mPagedKVCache based on use_kv_cache
op->mTokensPerBlock = tokens_per_block.value_or(0);
op->mMaxContextLength = max_context_length;
op->mQScaling = q_scaling;
op->mPositionEmbeddingType
@@ -434,7 +450,8 @@ torch::Tensor attention(torch::Tensor q, torch::optional<torch::Tensor> k, torch

if (is_mla_enable)
{
int32_t const layer_num = host_kv_cache_pool_mapping.size(0);
TLLM_CHECK(host_kv_cache_pool_mapping.has_value());
int32_t const layer_num = host_kv_cache_pool_mapping.value().size(0);
op->mIsMLAEnabled = true;
// only enable flash mla on sm90 and head_size == 576 and tokens_per_block == 64
op->mUseFlashMLA = tensorrt_llm::common::getSMVersion() == 90 && head_size == 576 && tokens_per_block == 64;
@@ -572,10 +589,10 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
", Tensor context_lengths"
", Tensor host_context_lengths"
", Tensor host_request_types"
", Tensor kv_cache_block_offsets"
", Tensor host_kv_cache_block_offsets"
", Tensor host_kv_cache_pool_pointers"
", Tensor host_kv_cache_pool_mapping"
", Tensor? kv_cache_block_offsets"
", Tensor? host_kv_cache_block_offsets"
", Tensor? host_kv_cache_pool_pointers"
", Tensor? host_kv_cache_pool_mapping"
", Tensor? cache_indirection"
", Tensor? kv_scale_orig_quant"
", Tensor? kv_scale_quant_orig"
@@ -592,7 +609,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
", int num_heads"
", int num_kv_heads"
", int head_size"
", int tokens_per_block"
", int? tokens_per_block"
", int max_num_requests"
", int max_context_length"
", int attention_window_size"
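For reference, the schema types above line up with the C++ parameter types earlier in this file through the standard torch custom-op binding rules, summarized here as a hedged sketch:

// Schema <-> C++ binding correspondence used by this op (sketch):
//   "Tensor? kv_cache_block_offsets"  <->  torch::optional<torch::Tensor> kv_cache_block_offsets
//   "int? tokens_per_block"           <->  std::optional<int64_t> tokens_per_block
// Passing None from the Python side leaves the corresponding optional empty; the op then derives
// use_kv_cache == false and takes the new no-cache attention path.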