diff --git a/truss/build.py b/truss/build.py
index e7202c5ba..33c7acddf 100644
--- a/truss/build.py
+++ b/truss/build.py
@@ -6,11 +6,17 @@
 import yaml
 
+from truss.config.trt_llm import (
+    CheckpointRepository,
+    CheckpointSource,
+    TRTLLMConfiguration,
+    TrussTRTLLMBuildConfiguration,
+)
 from truss.constants import CONFIG_FILE, TEMPLATES_DIR, TRUSS
 from truss.docker import kill_containers
 from truss.model_inference import infer_python_version, map_to_supported_python_version
 from truss.notebook import is_notebook_or_ipython
-from truss.truss_config import Build, TrussConfig
+from truss.truss_config import Accelerator, AcceleratorSpec, Build, TrussConfig
 from truss.truss_handle import TrussHandle
 from truss.util.path import build_truss_target_directory, copy_tree_path
 
 
@@ -54,6 +60,24 @@ def populate_target_directory(
     return target_directory_path_typed
 
 
+def set_trtllm_engine_builder_config(config):
+    config.resources.accelerator = AcceleratorSpec(
+        accelerator=Accelerator("A10G"), count=1
+    )
+    config.resources.use_gpu = True
+    trt_llm_build = TrussTRTLLMBuildConfiguration(
+        base_model="llama",
+        max_input_len=1024,
+        max_output_len=1024,
+        max_batch_size=1,
+        max_beam_width=1,
+        checkpoint_repository=CheckpointRepository(
+            source=CheckpointSource("HF"), repo=""
+        ),
+    )
+    config.trt_llm = TRTLLMConfiguration(build=trt_llm_build)
+
+
 def init(
     target_directory: str,
     data_files: Optional[List[str]] = None,
@@ -77,12 +101,19 @@ def init(
         python_version=map_to_supported_python_version(infer_python_version()),
     )
 
-    if build_config:
+    if build_config and build_config.model_server.value != "TRT_LLM_BUILDER":
         config.build = build_config
 
+    if build_config and build_config.model_server.value == "TRT_LLM_BUILDER":
+        template = "trtllm-engine-builder"
+        set_trtllm_engine_builder_config(config)
+    else:
+        template = "custom"
+
     target_directory_path = populate_target_directory(
         config=config,
         target_directory_path=target_directory,
+        template=template,
         populate_dirs=True,
     )
 
diff --git a/truss/config/trt_llm.py b/truss/config/trt_llm.py
index 42b23d499..d6ae40407 100644
--- a/truss/config/trt_llm.py
+++ b/truss/config/trt_llm.py
@@ -93,7 +93,9 @@ def _validate_minimum_required_configuration(self):
         if not self.serve and not self.build:
             raise ValueError("Either serve or build configurations must be provided")
         if self.serve and self.build:
-            raise ValueError("Both serve and build configurations cannot be provided")
+            raise ValueError(
+                "Exactly one of serve or build configurations must be provided, not both"
+            )
         if self.serve is not None:
             if (self.serve.engine_repository is None) ^ (
                 self.serve.tokenizer_repository is None
diff --git a/truss/templates/trtllm-engine-builder/__init__.py b/truss/templates/trtllm-engine-builder/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/truss/templates/trtllm-engine-builder/model/model.py b/truss/templates/trtllm-engine-builder/model/model.py
new file mode 100644
index 000000000..fbd98ce5b
--- /dev/null
+++ b/truss/templates/trtllm-engine-builder/model/model.py
@@ -0,0 +1,30 @@
+"""
+The `Model` class allows you to customize the behavior of your TensorRT-LLM engine.
+
+The main methods to implement here are:
+* `load`: runs exactly once when the model server is spun up or patched and loads the
+    model onto the model server. Include any logic for initializing your model server.
+* `predict`: runs every time the model server is called. Include any logic for model
+    inference and return the model output.
+
+See https://docs.baseten.co/performance/engine-builder-customization for more.
+"""
+
+
+class Model:
+    def __init__(self, trt_llm, **kwargs):
+        # Uncomment the following to get access
+        # to various parts of the Truss config.
+
+        # self._data_dir = kwargs["data_dir"]
+        # self._config = kwargs["config"]
+        # self._secrets = kwargs["secrets"]
+        self._engine = trt_llm["engine"]
+
+    def load(self):
+        # Include any logic for initializing your model server here.
+        pass
+
+    async def predict(self, model_input):
+        # Run model inference here
+        return await self._engine.predict(model_input)
diff --git a/truss/truss_config.py b/truss/truss_config.py
index fa09d2ff4..60b5447e5 100644
--- a/truss/truss_config.py
+++ b/truss/truss_config.py
@@ -174,6 +174,7 @@ class ModelServer(Enum):
 
     TrussServer = "TrussServer"
     TRT_LLM = "TRT_LLM"
+    TRT_LLM_BUILDER = "TRT_LLM_BUILDER"
 
 
 @dataclass
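
Usage sketch (not part of the diff above): a minimal example of how the new TRT_LLM_BUILDER path in truss.build.init() is expected to be exercised, assuming init() accepts the build_config argument shown in the hunk above and that Build and ModelServer are importable from truss.truss_config; the target directory name is illustrative.

# Minimal sketch, assumptions noted above; not part of this change set.
from truss.build import init
from truss.truss_config import Build, ModelServer

# Passing ModelServer.TRT_LLM_BUILDER makes init() select the
# "trtllm-engine-builder" template and call set_trtllm_engine_builder_config(),
# which sets an A10G accelerator and a default trt_llm build block on the config
# instead of assigning config.build.
init(
    target_directory="./my-trtllm-truss",  # illustrative path
    build_config=Build(model_server=ModelServer.TRT_LLM_BUILDER),
)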