
Commit c104768

Update TensorRT-LLM backend (#635)
1 parent 84ab8f6 commit c104768

32 files changed: +1529 −253 lines

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
@@ -47,3 +47,4 @@ repos:
         exclude: tools/dataset/
         args:
           - --skip=".git,tensorrt_llm"
+          - --exclude-file=all_models/whisper/whisper_bls/1/tokenizer.py

all_models/inflight_batcher_llm/tensorrt_llm/1/model.py

Lines changed: 53 additions & 37 deletions
@@ -343,7 +343,8 @@ def convert_request(request, exclude_input_from_output, decoupled):
     return requests


-def convert_response(response, batch_index):
+def convert_response(response, batch_index, batch_size, num_return_sequences):
+
     if response.has_error():
         return pb_utils.InferenceResponse(output_tensors=[],
                                           error=pb_utils.TritonError(
@@ -356,40 +357,50 @@ def convert_response(response, batch_index):
                               -1, np.int32)
     for idx, beam in enumerate(result.output_token_ids):
         output_ids[0, idx, :len(beam)] = beam
+
     output_tensors = [
         pb_utils.Tensor("output_ids", output_ids),
         pb_utils.Tensor("sequence_length", beam_lengths),
     ]
-    output_tensors.append(
-        pb_utils.Tensor(
-            "cum_log_probs",
-            np.expand_dims(np.array(result.cum_log_probs, np.float32), 0)
-            if result.cum_log_probs is not None else np.zeros(
-                (1, 1), np.float32)))
-    output_tensors.append(
-        pb_utils.Tensor(
-            "output_log_probs",
-            np.expand_dims(np.array(result.log_probs, np.float32), 0) if
-            result.log_probs is not None else np.zeros((1, 1, 1), np.float32)))
-    output_tensors.append(
-        pb_utils.Tensor(
-            "context_logits",
-            np.expand_dims(np.array(result.context_logits, np.float32), 0)
-            if result.context_logits is not None else np.zeros(
-                (1, 1, 1), np.float32)))
-    output_tensors.append(
-        pb_utils.Tensor(
-            "generation_logits",
-            np.expand_dims(np.array(result.generation_logits, np.float32), 0)
-            if result.generation_logits is not None else np.zeros(
-                (1, 1, 1, 1), np.float32)))
-    output_tensors.append(
-        pb_utils.Tensor("batch_index",
-                        np.expand_dims(np.array([batch_index], np.int32), 0)))
-    output_tensors.append(
-        pb_utils.Tensor(
-            "sequence_index",
-            np.expand_dims(np.array([result.sequence_index], np.int32), 0)))
+
+    if result.cum_log_probs is not None:
+        output_tensors.append(
+            pb_utils.Tensor(
+                "cum_log_probs",
+                np.expand_dims(np.array(result.cum_log_probs, np.float32), 0)))
+
+    if result.log_probs is not None:
+        output_tensors.append(
+            pb_utils.Tensor(
+                "output_log_probs",
+                np.expand_dims(np.array(result.log_probs, np.float32), 0)))
+
+    if result.context_logits is not None:
+        output_tensors.append(
+            pb_utils.Tensor(
+                "context_logits",
+                np.expand_dims(np.array(result.context_logits, np.float32),
+                               0)))
+
+    if result.generation_logits is not None:
+        output_tensors.append(
+            pb_utils.Tensor(
+                "generation_logits",
+                np.expand_dims(np.array(result.generation_logits, np.float32),
+                               0)))
+
+    if batch_size > 1:
+        output_tensors.append(
+            pb_utils.Tensor(
+                "batch_index",
+                np.expand_dims(np.array([batch_index], np.int32), 0)))
+
+    if num_return_sequences > 1:
+        output_tensors.append(
+            pb_utils.Tensor(
+                "sequence_index",
+                np.expand_dims(np.array([result.sequence_index], np.int32),
+                               0)))

     return pb_utils.InferenceResponse(output_tensors), result.is_final

@@ -466,6 +477,8 @@ def get_kv_cache_config(self, model_config):
         "free_gpu_memory_fraction":
         get_parameter(model_config, "kv_cache_free_gpu_mem_fraction",
                       float),
+        "cross_kv_cache_fraction":
+        get_parameter(model_config, "cross_kv_cache_fraction", float),
         "host_cache_size":
         get_parameter(model_config, "kv_cache_host_memory_bytes", int),
         "onboard_blocks":
@@ -876,11 +889,14 @@ def execute(self, requests):

         with self.lock:
             request_ids = self.executor.enqueue_requests(executor_requests)
-            for req_id, triton_req_id, triton_user_id, triton_request, batch_index in zip(
+            for req_id, triton_req_id, triton_user_id, executor_request, triton_request, batch_index in zip(
                     request_ids, triton_req_ids, triton_user_ids,
-                    triton_requests, batch_indices):
+                    executor_requests, triton_requests, batch_indices):
+
                 self.req_id_to_request_data[
-                    req_id] = triton_req_id, triton_user_id, batch_index, triton_request.get_response_sender(
+                    req_id] = triton_req_id, triton_user_id, batch_index, len(
+                        batch_indices
+                    ), executor_request.num_return_sequences, triton_request.get_response_sender(
                     )
                 self.triton_req_id_to_req_ids[triton_req_id].add(req_id)
                 if triton_user_id is not None and triton_user_id != "":
@@ -897,11 +913,11 @@ def awaiter_loop(self):
             with self.lock:
                 if req_id not in self.req_id_to_request_data:
                     continue
-                triton_req_id, triton_user_id, batch_index, response_sender = self.req_id_to_request_data[
+                triton_req_id, triton_user_id, batch_index, batch_size, num_return_sequences, response_sender = self.req_id_to_request_data[
                     req_id]

             triton_response, is_final = convert_response(
-                response, batch_index)
+                response, batch_index, batch_size, num_return_sequences)

             triton_request_final = False
             if is_final:
@@ -935,7 +951,7 @@ def cancellation_loop(self):
             time.sleep(self.cancellation_check_period_ms / 1000.0)
             with self.lock:
                 for req_id, (triton_req_id, triton_user_id, batch_index,
-                             response_sender
+                             batch_size, num_return_sequences, response_sender
                              ) in self.req_id_to_request_data.items():
                     if response_sender.is_cancelled():
                         self.executor.cancel_request(req_id)
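
Taken together, the model.py changes thread two extra values, the batch size and the number of return sequences, from enqueue time through to response conversion, so convert_response can omit the batch_index and sequence_index tensors when they carry no information. A minimal sketch of the new bookkeeping (plain Python; the ids and the None response_sender are placeholders standing in for real Triton objects):

# Each executor request id now maps to a six-field tuple instead of four.
req_id_to_request_data = {}
req_id = 7  # hypothetical executor request id
req_id_to_request_data[req_id] = (
    1001,    # triton_req_id
    "user",  # triton_user_id
    0,       # batch_index
    1,       # batch_size, recorded as len(batch_indices)
    1,       # num_return_sequences, taken from the executor request
    None,    # placeholder for triton_request.get_response_sender()
)

# awaiter_loop unpacks all six fields and forwards the two new ones:
(triton_req_id, triton_user_id, batch_index, batch_size,
 num_return_sequences, response_sender) = req_id_to_request_data[req_id]

# convert_response(response, batch_index, batch_size, num_return_sequences)
# then appends "batch_index" only when batch_size > 1 and "sequence_index"
# only when num_return_sequences > 1, mirroring the None checks used for
# the optional log-prob and logits tensors.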

all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt

Lines changed: 21 additions & 0 deletions
@@ -44,6 +44,21 @@ input [
     data_type: TYPE_INT32
     dims: [ -1 ]
     allow_ragged_batch: true
+    optional: true
+  },
+  {
+    name: "encoder_input_features"
+    data_type: TYPE_FP16
+    dims: [ -1, -1 ]
+    allow_ragged_batch: true
+    optional: true
+  },
+  {
+    name: "encoder_output_lengths"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
   },
   {
     name: "input_lengths"
@@ -465,6 +480,12 @@ parameters: {
     string_value: "${kv_cache_free_gpu_mem_fraction}"
   }
 }
+parameters: {
+  key: "cross_kv_cache_fraction"
+  value: {
+    string_value: "${cross_kv_cache_fraction}"
+  }
+}
 parameters: {
   key: "kv_cache_host_memory_bytes"
   value: {
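
The two new optional inputs exist for encoder-decoder models such as Whisper, whose encoder consumes audio features rather than token ids. A client-side sketch of supplying them (the tensor names and dtypes come from the config above; the server URL, model name, feature layout, and lengths are illustrative assumptions):

import numpy as np
import tritonclient.grpc as grpcclient

# Hypothetical Whisper-style features: batch of 1, 128 mel bins x 3000 frames.
features = np.zeros((1, 128, 3000), dtype=np.float16)

client = grpcclient.InferenceServerClient("localhost:8001")
inputs = [
    grpcclient.InferInput("encoder_input_features", list(features.shape),
                          "FP16"),
    grpcclient.InferInput("encoder_output_lengths", [1, 1], "INT32"),
]
inputs[0].set_data_from_numpy(features)
# Assumed 2x encoder downsampling: 3000 frames -> 1500 encoder steps.
inputs[1].set_data_from_numpy(np.array([[1500]], dtype=np.int32))
# The remaining required inputs (input_ids, request_output_len, ...) are
# omitted; this only illustrates the new optional tensors.
# result = client.infer("tensorrt_llm", inputs)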

all_models/inflight_batcher_llm/tensorrt_llm_bls/1/lib/decode.py

Lines changed: 10 additions & 8 deletions
@@ -358,14 +358,16 @@ def postprocess(self, gen_response: GenerationResponse,
         )

         batch_index = gen_response.batch_index
-        if batch_index.ndim != 2:
-            raise Exception("Expected batch_index tensor to have 2 dims.")
-        if batch_index.shape[0] != 1:
-            raise Exception("Expected batch size of 1")
-        if batch_index.shape[1] != 1:
-            raise Exception("Expected only one batch_index")
-
-        batch_index = batch_index[0][0]
+        if batch_index is not None:
+            if batch_index.ndim != 2:
+                raise Exception(
+                    "Expected batch_index tensor to have 2 dims.")
+            if batch_index.shape[0] != 1:
+                raise Exception("Expected batch size of 1")
+            if batch_index.shape[1] != 1:
+                raise Exception("Expected only one batch_index")
+
+        batch_index = batch_index[0][0] if batch_index is not None else 0

         self._accumulated_tokens[batch_index] = new_tokens if (
             self._accumulated_tokens[batch_index] is None
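
Since the tensorrt_llm model now sends batch_index only when the batch holds more than one request, the BLS postprocess must tolerate its absence and fall back to slot 0. A small sketch of that fallback (resolve_batch_index is a hypothetical helper, not part of the module):

import numpy as np

def resolve_batch_index(batch_index):
    # Mirrors the guarded checks above: validate only when the tensor
    # arrived, otherwise default to the first (and only) batch slot.
    if batch_index is not None:
        if batch_index.ndim != 2 or batch_index.shape != (1, 1):
            raise Exception("Expected a single [1, 1] batch_index tensor.")
    return int(batch_index[0][0]) if batch_index is not None else 0

print(resolve_batch_index(None))                       # -> 0
print(resolve_batch_index(np.array([[2]], np.int32)))  # -> 2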

all_models/tests/test_python_backend.py

Lines changed: 20 additions & 11 deletions
@@ -539,7 +539,10 @@ def test_convert_request_invalid():

 def test_convert_response(trtllm_response: trtllm.Response):
     batch_index = 2
-    response, is_final = convert_response(trtllm_response, batch_index)
+    batch_size = 3
+    num_return_sequences = 1
+    response, is_final = convert_response(trtllm_response, batch_index,
+                                          batch_size, num_return_sequences)
     assert is_final == True
     assert (response.tensors["output_ids"].as_numpy() == np.array([[1, 2, 3]
                                                                    ])).all()
@@ -559,27 +562,30 @@ def test_convert_response(trtllm_response: trtllm.Response):

 def test_convert_response_minimal(trtllm_response_minimal: trtllm.Response):
     batch_index = 2
-    response, is_final = convert_response(trtllm_response_minimal, batch_index)
+    batch_size = 3
+    num_return_sequences = 1
+    response, is_final = convert_response(trtllm_response_minimal, batch_index,
+                                          batch_size, num_return_sequences)
     assert is_final == False
     assert (response.tensors["output_ids"].as_numpy() == np.array([[1, 2, 3]
                                                                    ])).all()
     assert (response.tensors["sequence_length"].as_numpy() == np.array(
         [[3]])).all()
-    assert (response.tensors["cum_log_probs"].as_numpy() == np.zeros(
-        (1, 1), np.float32)).all()
-    assert (response.tensors["output_log_probs"].as_numpy() == np.zeros(
-        (1, 1, 1), np.float32)).all()
-    assert (response.tensors["context_logits"].as_numpy() == np.zeros(
-        (1, 1, 1), np.float32)).all()
-    assert (response.tensors["generation_logits"].as_numpy() == np.zeros(
-        (1, 1, 1, 1), np.float32)).all()
+    assert "cum_log_probs" not in response.tensors
+    assert "output_log_probs" not in response.tensors
+    assert "output_log_probs" not in response.tensors
+    assert "context_logits" not in response.tensors
+    assert "generation_logits" not in response.tensors
     assert (response.tensors["batch_index"].as_numpy() == np.array(
         [[batch_index]])).all()


 def test_convert_response_error(trtllm_response_error: trtllm.Response):
     batch_index = 2
-    response, is_final = convert_response(trtllm_response_error, batch_index)
+    batch_size = 3
+    num_return_sequences = 1
+    response, is_final = convert_response(trtllm_response_error, batch_index,
+                                          batch_size, num_return_sequences)
     assert is_final == True
     assert response.has_error() and response.error.message == "internal error"

@@ -637,6 +643,7 @@ def model_config() -> Dict:
         "max_attention_window_size": "2",
         "sink_token_length": "3",
         "kv_cache_free_gpu_mem_fraction": "0.5",
+        "cross_kv_cache_fraction": "0.5",
         "kv_cache_host_memory_bytes": "4",
         "kv_cache_onboard_blocks": "false",
         "gpu_device_ids": "0,1,2,3",
@@ -665,6 +672,7 @@ def test_get_executor_config(model_config: Dict):
     assert config.kv_cache_config.max_attention_window == [2]
     assert config.kv_cache_config.sink_token_length == 3
     assert config.kv_cache_config.free_gpu_memory_fraction == 0.5
+    assert config.kv_cache_config.cross_kv_cache_fraction == 0.5
     assert config.kv_cache_config.host_cache_size == 4
     assert config.kv_cache_config.onboard_blocks == False
     assert config.parallel_config.device_ids == [0, 1, 2, 3]
@@ -707,6 +715,7 @@ def test_get_executor_config_minimal():
     assert config.kv_cache_config.max_attention_window is None
     assert config.kv_cache_config.sink_token_length is None
     assert config.kv_cache_config.free_gpu_memory_fraction is None
+    assert config.kv_cache_config.cross_kv_cache_fraction is None
     assert config.kv_cache_config.host_cache_size is None
     assert config.kv_cache_config.onboard_blocks == True
     assert config.parallel_config is None
Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Reference: https://github.com/openai/whisper/blob/main/whisper/audio.py
+import os
+from typing import Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+def mel_filters(device, n_mels: int = 128) -> torch.Tensor:
+    """
+    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
+    Allows decoupling librosa dependency; saved using:
+
+        np.savez_compressed(
+            "mel_filters.npz",
+            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
+        )
+    """
+    assert n_mels == 80 or n_mels == 128, f"Unsupported n_mels: {n_mels}"
+    with np.load(os.path.join(os.path.dirname(__file__),
+                              "mel_filters.npz")) as f:
+        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
+
+
+def log_mel_spectrogram(
+    audio: Union[torch.Tensor],
+    filters: torch.Tensor,
+    n_mels: int = 128,
+    n_fft: int = 400,
+    hop_length: int = 160,
+):
+    """
+    Compute the log-Mel spectrogram of audio
+
+    Parameters
+    ----------
+    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
+        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
+
+    n_mels: int
+        The number of Mel-frequency filters, only 80 or 128 is supported
+
+    filters: torch.Tensor
+
+    Returns
+    -------
+    torch.Tensor, shape = (128, n_frames)
+        A Tensor that contains the Mel spectrogram
+    """
+    window = torch.hann_window(n_fft).to(audio.device)
+    stft = torch.stft(audio,
+                      n_fft,
+                      hop_length,
+                      window=window,
+                      return_complex=True)
+    magnitudes = stft[..., :-1].abs()**2
+
+    mel_spec = filters @ magnitudes
+    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
+    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+    log_spec = (log_spec + 4.0) / 4.0
+    # cast to float 16
+    log_spec = log_spec.half()
+    return log_spec
+
+
+class FeatureExtractor(torch.nn.Module):
+    """Your Python model must use the same class name. Every Python model
+    that is created must have "TritonPythonModel" as the class name.
+    """
+
+    def __init__(self, n_mels: int = 128):
+        self.device = torch.device("cuda")
+        self.n_mels = n_mels
+        self.filters = mel_filters(self.device, n_mels=self.n_mels)
+
+    def compute_feature(self, wav, target: int = 3000):
+        mel = log_mel_spectrogram(wav, self.filters)
+        if mel.shape[1] < target:
+            mel = F.pad(mel, (0, target - mel.shape[1]), mode='constant')
+        if mel.shape[1] % 2:
+            # pad to even length for remove_padding case, since conv1d requires even length
+            mel = torch.nn.functional.pad(mel, (0, 1))
+        mel = mel.unsqueeze(0)
+        return mel
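
A minimal usage sketch for the feature extractor above (it assumes a CUDA device and the mel_filters.npz bundled next to the module; one second of silence stands in for real 16 kHz audio):

import torch

extractor = FeatureExtractor(n_mels=128)
wav = torch.zeros(16000, device=extractor.device)  # 1 s of 16 kHz audio
mel = extractor.compute_feature(wav)               # padded out to 3000 frames
print(mel.shape, mel.dtype)  # torch.Size([1, 128, 3000]) torch.float16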
