Skip to content

Commit af1c061

Browse files
committed
issue on truss config
1 parent 43e5883 commit af1c061

File tree

2 files changed

+21
-13
lines changed

2 files changed

+21
-13
lines changed

truss/base/trt_llm_config.py

+15-9
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,6 @@
1919
os.environ.get("ENGINE_BUILDER_TRUSS_RUNTIME_MIGRATION", "False") == "True"
2020
)
2121

22-
# Workaround for briton import without truss installed.
23-
try:
24-
from truss.base.truss_config import Resources
25-
except ImportError:
26-
27-
class Resources(BaseModel): # type: ignore
28-
pass
29-
3022

3123
class TrussTRTLLMModel(str, Enum):
3224
ENCODER = "encoder"
@@ -156,19 +148,33 @@ class TrussTRTLLMBuildConfiguration(BaseModel):
156148
TrussTRTLLMPluginConfiguration()
157149
)
158150
num_builder_gpus: Optional[int] = None
159-
build_resources: Optional[Resources] = None
151+
build_resources: Optional[Any] = None
160152
speculator: Optional[TrussSpeculatorConfiguration] = None
161153

162154
class Config:
163155
extra = "forbid"
164156

165157
def __init__(self, **data):
158+
data = self.parse_build_resources(data)
166159
super().__init__(**data)
167160
self._validate_kv_cache_flags()
168161
self._validate_speculator_config()
169162
self._bei_specfic_migration()
170163
self._depreacate_num_builder_gpus()
171164

165+
@staticmethod
166+
def parse_build_resources(data):
167+
build_resources = data.get("build_resources")
168+
if build_resources:
169+
try:
170+
from truss.base.truss_config import Resources
171+
172+
print("build_resources", build_resources)
173+
data["build_resources"] = Resources.from_dict(build_resources)
174+
except ImportError:
175+
pass
176+
return data
177+
172178
def _depreacate_num_builder_gpus(self):
173179
if self.num_builder_gpus:
174180
logger.warning(

truss/tests/trt_llm/test_trt_llm_config.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,13 @@ def test_trt_llm_chunked_prefill_fix(trtllm_config):
7474

7575

7676
def test_trt_llm_builder_gpus(trtllm_config):
77+
db_dict = {"cpu": "1", "accelerator": "A10G:2", "memory": "1G", "node_count": None}
78+
trtllm_config["trt_llm"]["build"]["build_resources"] = db_dict
79+
7780
trt_llm_config = TRTLLMConfiguration(**trtllm_config["trt_llm"])
78-
trt_llm_config.build.build_resources = Resources(
79-
cpu=1, accelerator="A10G:2", memory="1G"
80-
)
81-
assert trt_llm_config.build.build_resources.accelerator == "A10G:2"
81+
build_resources = Resources.from_dict(db_dict)
82+
trt_llm_config.build.build_resources = build_resources
83+
assert trt_llm_config.build.build_resources == build_resources
8284

8385

8486
def test_trt_llm_lookahead_decoding(trtllm_config):

0 commit comments

Comments
 (0)