
Commit bd371f8

refactor: Resolve comments for Python code
Simplify no cache attention metadata preparation and streamline related attributes in TrtllmAttentionMetadata. Removed the private method for converting to no cache attention metadata and integrated its logic into the prepare method. Updated the test for BERT sequence classification to reflect these changes and to ensure proper handling of attention metadata.

Signed-off-by: Qixiang Lin <[email protected]>
1 parent f23e17a commit bd371f8
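For orientation, here is a condensed sketch of where the no cache handling now lives, assembled from the trtllm.py diff below; it abbreviates TrtllmAttentionMetadata.prepare() rather than quoting it verbatim, and the trailing comment stands in for the unchanged remainder of the method.

    def prepare(self) -> None:
        if not self.is_dummy_attention and self.kv_cache_manager is None:
            # No KV cache manager: run as no cache attention.
            assert self._max_seq_len_storage is not None, "max_seq_len should be set for no cache attention"
            self.kv_cache_params = KVCacheParams(use_cache=False)
            self.prompt_lens = self.context_lens      # the rest of prepare() requires this
            self.kv_cache_block_offsets = None        # params consumed by wrapper.plan()
            self.host_kv_cache_block_offsets = None
            self.block_ids_per_seq = None
        # ... the existing prepare() logic (prompt_lens tensor, kv_lens checks, etc.) follows here ...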

2 files changed: +28, -41 lines

Diff for: tensorrt_llm/_torch/attention_backend/trtllm.py (+18, -23)

@@ -205,9 +205,9 @@ def plan(
            host_context_lengths (torch.Tensor): Same as context_lengths, but on CPU.
            host_request_types (torch.Tensor): The tensor that indicates whether a request is in context or generation phase, with shape (batch_size) on CPU.
            kv_cache_block_offsets (torch.Tensor): The offsets to the blocks inside KV cache pools on GPU, its shape is (num_pools, max_batch_size * max_beam_width, 2, max_blocks_per_sequence), one for each block. If kv_cache_block_offsets, host_kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping are all None, the attention will be no cache attention.
-           host_kv_cache_block_offsets (torch.Tensor): Same as kv_cache_block_offsets, but on CPU. If kv_cache_block_offsets, host_kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping are all None, the attention will be no cache attention.
-           host_kv_cache_pool_pointers (torch.Tensor): The pointers to the KV cache pools on CPU, its shape is (num_pools, 2), one for primary pool in GPU memory, one for secondary pool in CPU memory. If kv_cache_block_offsets, host_kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping are all None, the attention will be no cache attention.
-           host_kv_cache_pool_mapping (torch.Tensor): The index of the pool used by each attention layer on CPU, its shape is (num_local_attention_layers). The local attention layers mean all attention layers in the current PP stage in the pipeline parallelism case. If kv_cache_block_offsets, host_kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping are all None, the attention will be no cache attention.
+           host_kv_cache_block_offsets (torch.Tensor): Same as kv_cache_block_offsets, but on CPU.
+           host_kv_cache_pool_pointers (torch.Tensor): The pointers to the KV cache pools on CPU, its shape is (num_pools, 2), one for primary pool in GPU memory, one for secondary pool in CPU memory.
+           host_kv_cache_pool_mapping (torch.Tensor): The index of the pool used by each attention layer on CPU, its shape is (num_local_attention_layers). The local attention layers mean all attention layers in the current PP stage in the pipeline parallelism case.
            workspace (torch.Tensor): An optional workspace tensor on GPU.
            cache_indirection (torch.Tensor): A tensor for beam search on GPU, its shape is (batch_size, beam_width, max_seqlen), for a sequence si, a beam bi and a token ti, the element cache_indirection[si][bi][ti] is an integer between 0 and beam_width-1 that indicates which path in the beam to read the K and V elements from in the KV cache.
            kv_scale_orig_quant (torch.Tensor): The tensor to store the scaling factor for quantization to INT8/FP8 in the KV cache, with shape (1) on GPU.

@@ -523,7 +523,21 @@ def __post_init__(self) -> None:
    def prepare(self) -> None:

        if not self.is_dummy_attention and self.kv_cache_manager is None:
-           self._as_no_cache_attention_metadata()
+           # Convert the attention metadata to a TRT-LLM no cache attention metadata.
+           assert self.kv_cache_manager is None, "no cache attention should not have KV cache manager"
+           assert self._max_seq_len_storage is not None, "max_seq_len should be set for no cache attention"
+
+           # setting kv cache params
+           self.kv_cache_params = KVCacheParams(use_cache=False, )
+
+           # trtllm attn metadata prepare() requires this
+           self.prompt_lens = self.context_lens
+
+           # set params that are used in wrapper.plan()
+           self.kv_cache_block_offsets = None
+           self.host_kv_cache_block_offsets = None
+           self.block_ids_per_seq = None
+
        prompt_lens = torch.tensor(
            self.prompt_lens,
            dtype=torch.int,

@@ -579,25 +593,6 @@ def prepare(self) -> None:
        assert self.kv_lens[:self.num_seqs].max(
        ) <= self.kv_cache_manager.max_seq_len, f"Please set max_seq_len to at least {self.kv_lens[:self.num_seqs].max()} for kv cache manager."

-   def _as_no_cache_attention_metadata(self) -> None:
-       """
-       Convert the attention metadata to a TRT-LLM no cache attention metadata.
-       This is a private method and should not be directly called.
-       """
-       assert self.kv_cache_manager is None, "no cache attention should not have KV cache manager"
-       assert self._max_seq_len_storage is not None, "max_seq_len should be set for no cache attention"
-
-       # setting kv cache params
-       self.kv_cache_params = KVCacheParams(use_cache=False, )
-
-       # trtllm attn metadata prepare() requires this
-       self.prompt_lens = self.context_lens
-
-       # set params that are used in wrapper.plan()
-       self.kv_cache_block_offsets = None
-       self.host_kv_cache_block_offsets = None
-       self.block_ids_per_seq = None
-

class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):

Diff for: tests/unittest/_torch/modeling/test_modeling_bert.py (+10, -18)

@@ -105,21 +105,16 @@ def test_bert_allclose_to_hf(self, scenario: Scenario):
        # Fill the metadata for tllm attn
        request_ids = [1]
        prompt_lens = [input_ids.size(-1)]
-       kwargs = {
-           "max_num_requests": 1,
-           "max_num_tokens": 8192,
-           "kv_cache_manager": None,
-           "request_ids": request_ids,
-           "prompt_lens": prompt_lens,
-       }
-       # if metadata_cls is TrtllmAttentionMetadata:
-       # kwargs["is_no_cache"] = True
-       # kwargs["max_seq_len"] = input_ids.size(-1)
-
-       attn_metadata = metadata_cls(**kwargs)
-       attn_metadata.seq_lens = torch.tensor([input_ids.size(-1)],
-                                             dtype=torch.int)
-       attn_metadata.num_contexts = 1
+
+       attn_metadata = metadata_cls(
+           max_num_requests=1,
+           max_num_tokens=8192,
+           kv_cache_manager=None,
+           request_ids=request_ids,
+           prompt_lens=prompt_lens,
+           seq_lens=torch.tensor([input_ids.size(-1)], dtype=torch.int),
+           num_contexts=1,
+       )
        attn_metadata.max_seq_len = input_ids.size(-1)
        attn_metadata.prepare()

@@ -129,9 +124,6 @@ def test_bert_allclose_to_hf(self, scenario: Scenario):

        # Run inference
        with torch.inference_mode():
-           if backend == 'TRTLLM':
-               attn_metadata.prepare()
-           #NOTE:attn_metadata.prepare is not needed for no cache case
            # TRT-LLM model forward
            tllm_outputs = tllm_model(
                input_ids=input_ids,

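With these two changes, setting up no cache attention is a single construct-and-prepare step. Below is a minimal sketch of that flow, mirroring the updated test above; the import path is assumed from the file changed in this commit, and the token tensor is a stand-in for real input_ids.

    import torch

    # Assumed import path, based on tensorrt_llm/_torch/attention_backend/trtllm.py above.
    from tensorrt_llm._torch.attention_backend.trtllm import TrtllmAttentionMetadata

    input_ids = torch.randint(0, 1000, (1, 32))  # stand-in (1, seq_len) token tensor
    seq_len = input_ids.size(-1)

    attn_metadata = TrtllmAttentionMetadata(
        max_num_requests=1,
        max_num_tokens=8192,
        kv_cache_manager=None,  # no KV cache manager -> the no cache attention path in prepare()
        request_ids=[1],
        prompt_lens=[seq_len],
        seq_lens=torch.tensor([seq_len], dtype=torch.int),
        num_contexts=1,
    )
    attn_metadata.max_seq_len = seq_len  # no cache attention asserts that max_seq_len is set
    attn_metadata.prepare()              # a single prepare() now covers the no cache case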