|
19 | 19 | os.environ.get("ENGINE_BUILDER_TRUSS_RUNTIME_MIGRATION", "False") == "True"
|
20 | 20 | )
|
21 | 21 |
|
22 |
| -# Workaround for briton import without truss installed. |
23 |
| -try: |
24 |
| - from truss.base.truss_config import Resources |
25 |
| -except ImportError: |
26 |
| - |
27 |
| - class Resources(BaseModel): # type: ignore |
28 |
| - pass |
29 |
| - |
30 | 22 |
|
31 | 23 | class TrussTRTLLMModel(str, Enum):
|
32 | 24 | ENCODER = "encoder"
|
@@ -156,19 +148,33 @@ class TrussTRTLLMBuildConfiguration(BaseModel):
|
156 | 148 | TrussTRTLLMPluginConfiguration()
|
157 | 149 | )
|
158 | 150 | num_builder_gpus: Optional[int] = None
|
159 |
| - build_resources: Optional[Resources] = None |
| 151 | + build_resources: Optional[Any] = None |
160 | 152 | speculator: Optional[TrussSpeculatorConfiguration] = None
|
161 | 153 |
|
162 | 154 | class Config:
|
163 | 155 | extra = "forbid"
|
164 | 156 |
|
165 | 157 | def __init__(self, **data):
|
| 158 | + data = self.parse_build_resources(data) |
166 | 159 | super().__init__(**data)
|
167 | 160 | self._validate_kv_cache_flags()
|
168 | 161 | self._validate_speculator_config()
|
169 | 162 | self._bei_specfic_migration()
|
170 | 163 | self._depreacate_num_builder_gpus()
|
171 | 164 |
|
| 165 | + @staticmethod |
| 166 | + def parse_build_resources(data): |
| 167 | + build_resources = data.get("build_resources") |
| 168 | + if build_resources: |
| 169 | + try: |
| 170 | + from truss.base.truss_config import Resources |
| 171 | + |
| 172 | + print("build_resources", build_resources) |
| 173 | + data["build_resources"] = Resources.from_dict(build_resources) |
| 174 | + except ImportError: |
| 175 | + pass |
| 176 | + return data |
| 177 | + |
172 | 178 | def _depreacate_num_builder_gpus(self):
|
173 | 179 | if self.num_builder_gpus:
|
174 | 180 | logger.warning(
|
|
0 commit comments