Commit f23e17a

refactor: streamline KV cache handling by replacing direct member access with useKVCache method and simplify token per block assignment
Remove debug code.

Signed-off-by: Qixiang Lin <[email protected]>
1 parent ff606ee commit f23e17a

2 files changed, +11 -23 lines changed

Diff for: cpp/tensorrt_llm/common/attentionOp.cpp

+1 -5
@@ -2293,20 +2293,16 @@ int AttentionOp::initialize() noexcept

     if (isCrossAttention())
     {
-        // Temporary check for cross attention
-        TLLM_CHECK_DEBUG(mMaskType == tensorrt_llm::kernels::AttentionMaskType::PADDING);
         // always use paged-kv-fmha if paged_kv cache is used.
         fmhaParams.attentionInputLayout
             = mPagedKVCache ? AttentionInputLayout::Q_PAGED_KV : AttentionInputLayout::Q_CONTIGUOUS_KV;
     }
-    else if (!mUseKVCache)
+    else if (!useKVCache())
     {
         fmhaParams.attentionInputLayout = AttentionInputLayout::PACKED_QKV;
     }
     else
     {
-        // Temporary check for other attention types
-        TLLM_CHECK_DEBUG(mMaskType == tensorrt_llm::kernels::AttentionMaskType::CAUSAL);
         fmhaParams.attentionInputLayout = (mPagedKVCache && mPagedContextFMHA && !mIsMLAEnabled)
             ? AttentionInputLayout::Q_PAGED_KV
             : AttentionInputLayout::PACKED_QKV;
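
Besides dropping the temporary TLLM_CHECK_DEBUG mask-type checks (the "remove debug code" part of the commit message), this hunk swaps the direct read of mUseKVCache for the useKVCache() accessor. The accessor itself is not part of this diff; the sketch below assumes it is a trivial const getter over mUseKVCache, and the AttentionOpSketch type and pickLayout() helper are illustrative stand-ins, not the TensorRT-LLM source.

// Minimal sketch of the accessor pattern, assuming useKVCache() simply
// returns the mUseKVCache flag; not the real AttentionOp class.
#include <iostream>

enum class AttentionInputLayout
{
    PACKED_QKV,
    Q_CONTIGUOUS_KV,
    Q_PAGED_KV
};

struct AttentionOpSketch
{
    bool mUseKVCache = true;  // hypothetical stand-ins for the flags in the diff
    bool mPagedKVCache = true;

    // Assumed shape of the accessor: one place to answer "is a KV cache in use?",
    // so callers branch on this instead of touching the member directly.
    [[nodiscard]] bool useKVCache() const
    {
        return mUseKVCache;
    }

    // Illustrative condensation of the layout selection in initialize().
    AttentionInputLayout pickLayout() const
    {
        if (!useKVCache())
        {
            return AttentionInputLayout::PACKED_QKV;
        }
        return mPagedKVCache ? AttentionInputLayout::Q_PAGED_KV : AttentionInputLayout::Q_CONTIGUOUS_KV;
    }
};

int main()
{
    AttentionOpSketch op;
    op.mUseKVCache = false;
    // Prints 0, i.e. PACKED_QKV, because no KV cache is in use.
    std::cout << static_cast<int>(op.pickLayout()) << '\n';
    return 0;
}

Routing every caller through one accessor keeps the KV-cache decision in a single place, which is also what lets the thop binding below reuse the same predicate.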

Diff for: cpp/tensorrt_llm/thop/attentionOp.cpp

+10 -18
@@ -190,20 +190,15 @@ class Runner : public RunnerBase
         int const cyclic_attention_window_size = attention_window_size;
         bool const can_use_one_more_block = beam_width > 1;

-        int max_blocks_per_sequence = kv_cache_block_offsets.has_value() ? kv_cache_block_offsets.value().size(-1) : 0;
-        int32_t const pool_index = kv_cache_block_offsets.has_value()
-            ? host_kv_cache_pool_mapping.value().index({op.mLayerIdx, 0}).item<int32_t>()
-            : 0;
-        int32_t const layer_idx_in_cache_pool = kv_cache_block_offsets.has_value()
-            ? host_kv_cache_pool_mapping.value().index({op.mLayerIdx, 1}).item<int32_t>()
-            : 0;
-        KVBlockArray::DataType* block_offsets = static_cast<KVBlockArray::DataType*>(kv_cache_block_offsets.has_value()
-            ? kv_cache_block_offsets.value().index({pool_index, seq_offset}).data_ptr()
-            : nullptr);
-        KVBlockArray::DataType* host_block_offsets
-            = static_cast<KVBlockArray::DataType*>(host_kv_cache_block_offsets.has_value()
-                ? host_kv_cache_block_offsets.value().index({pool_index, seq_offset}).data_ptr()
-                : nullptr);
+        int max_blocks_per_sequence = op.useKVCache() ? kv_cache_block_offsets.value().size(-1) : 0;
+        int32_t const pool_index
+            = op.useKVCache() ? host_kv_cache_pool_mapping.value().index({op.mLayerIdx, 0}).item<int32_t>() : 0;
+        int32_t const layer_idx_in_cache_pool
+            = op.useKVCache() ? host_kv_cache_pool_mapping.value().index({op.mLayerIdx, 1}).item<int32_t>() : 0;
+        KVBlockArray::DataType* block_offsets = static_cast<KVBlockArray::DataType*>(
+            op.useKVCache() ? kv_cache_block_offsets.value().index({pool_index, seq_offset}).data_ptr() : nullptr);
+        KVBlockArray::DataType* host_block_offsets = static_cast<KVBlockArray::DataType*>(
+            op.useKVCache() ? host_kv_cache_block_offsets.value().index({pool_index, seq_offset}).data_ptr() : nullptr);

         auto const cache_elem_size = (op.mKVCacheQuantMode.hasKvCacheQuant() ? 1 : sizeof(T));
         auto const block_size = op.mTokensPerBlock * op.mNumKVHeads * op.mHeadSize;
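
This hunk collapses the per-tensor has_value() checks into a single op.useKVCache() predicate; the implicit assumption is that whenever useKVCache() returns true, the optional KV-cache tensors are populated, so calling .value() on them is safe. A minimal sketch of that guard pattern, using plain std::optional instead of torch::optional<torch::Tensor> and illustrative names:

// Minimal sketch of the guard pattern; names and types are stand-ins,
// not the real thop binding API.
#include <cstdint>
#include <iostream>
#include <optional>

struct OpSketch
{
    bool mUseKVCache = true;
    bool useKVCache() const { return mUseKVCache; }
};

int main()
{
    OpSketch op;
    op.mUseKVCache = false;

    // Stand-ins for kv_cache_block_offsets / host_kv_cache_pool_mapping.
    std::optional<int64_t> kv_cache_block_offsets;      // empty when no KV cache
    std::optional<int32_t> host_kv_cache_pool_mapping;  // empty when no KV cache

    // One predicate gates every access; when it is true the optionals are
    // assumed to be populated, so .value() does not throw.
    int64_t const max_blocks_per_sequence = op.useKVCache() ? kv_cache_block_offsets.value() : 0;
    int32_t const pool_index = op.useKVCache() ? host_kv_cache_pool_mapping.value() : 0;

    std::cout << max_blocks_per_sequence << ' ' << pool_index << '\n'; // prints "0 0"
    return 0;
}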
@@ -434,10 +429,7 @@ torch::Tensor attention(torch::Tensor q, torch::optional<torch::Tensor> k, torch
     op->mKVCacheQuantMode = tensorrt_llm::common::QuantMode(uint32_t(quant_mode));
     op->mUseKVCache = use_kv_cache;
     op->mPagedKVCache = op->mPagedKVCache && use_kv_cache; // update mPagedKVCache based on use_kv_cache
-    if (tokens_per_block.has_value())
-    {
-        op->mTokensPerBlock = tokens_per_block.value();
-    }
+    op->mTokensPerBlock = tokens_per_block.value_or(0);
     op->mMaxContextLength = max_context_length;
     op->mQScaling = q_scaling;
     op->mPositionEmbeddingType
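
The tokens-per-block change replaces the has_value()/value() branch with value_or. One subtlety visible in the diff: the old code left op->mTokensPerBlock untouched when tokens_per_block was empty, while the new code assigns 0. A small self-contained sketch of that difference, using plain std::optional and illustrative variable names:

// Minimal sketch of the value_or simplification; tokens_per_block here is a
// plain std::optional<int>, not the torch::optional used in the real binding.
#include <iostream>
#include <optional>

int main()
{
    std::optional<int> tokens_per_block; // empty, e.g. when no KV cache is configured

    // Old shape of the code: only assign when a value is present.
    int tokensPerBlockOld = -1; // hypothetical prior value, kept if the optional is empty
    if (tokens_per_block.has_value())
    {
        tokensPerBlockOld = tokens_per_block.value();
    }

    // New shape: always assign, defaulting to 0 when the optional is empty.
    int const tokensPerBlockNew = tokens_per_block.value_or(0);

    std::cout << tokensPerBlockOld << ' ' << tokensPerBlockNew << '\n'; // prints "-1 0"
    return 0;
}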
