diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 9f93fe4b4e..9a04aba6de 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -453,6 +453,10 @@ std::vector TRTEngine::serialize() { return serialized_info; } +void TRTEngine::reset_captured_graph() { + cudagraph.reset(); +} + } // namespace runtime } // namespace core } // namespace torch_tensorrt diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index e9b1905610..2db640b6b1 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -185,6 +185,7 @@ struct TRTEngine : torch::CustomClassHolder { // c10::List Run(c10::List inputs); void set_profiling_paths(); + void reset_captured_graph(); #ifndef NDEBUG bool profile_execution = true; #else diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index c05be4e8aa..cbe19b0af6 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -88,6 +88,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = .def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info) .def("get_engine_layer_info", &TRTEngine::get_engine_layer_info) .def("infer_outputs", &TRTEngine::infer_outputs) + .def("reset_captured_graph", &TRTEngine::reset_captured_graph) .def_readwrite("use_pre_allocated_outputs", &TRTEngine::use_pre_allocated_outputs) .def_readwrite("use_output_allocator_outputs", &TRTEngine::use_output_allocator_outputs) .def_property( diff --git a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py index 5af9b11a4b..9e54fbac3d 100644 --- a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py @@ -103,9 +103,13 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool: return False - def __del__(self) -> None: + def _reset_captured_graph(self) -> None: if self.cudagraph: self.cudagraph.reset() + self.cudagraph = None + + def __del__(self) -> None: + self._reset_captured_graph() def set_use_output_allocator(self, enable: bool) -> None: self.use_output_allocator_outputs = enable @@ -119,8 +123,7 @@ def forward( shape_changed = self.validate_input_shapes(inputs) need_cudagraphs_record = shape_changed or self.is_weight_streaming_set if need_cudagraphs_record: - if self.cudagraph: - self.cudagraph.reset() + self._reset_captured_graph() self._input_buffers = [None] * len(inputs) self.is_weight_streaming_set = False @@ -196,7 +199,5 @@ def forward( return outputs[0] return outputs else: - if self.cudagraph: - self.cudagraph.reset() - self.cudagraph = None + self._reset_captured_graph() return self.compiled_module(*args, **kwargs) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 891d063ed3..6415ce11c3 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -333,9 +333,13 @@ def __deepcopy__(self, memo: Any) -> PythonTorchTensorRTModule: result.__setstate__(self.__getstate__()) return result - def __del__(self) -> None: + def _reset_captured_graph(self) -> None: if self.cudagraph: self.cudagraph.reset() + self.cudagraph = None + + def __del__(self) -> None: + self._reset_captured_graph() def setup_input_tensors( self, @@ -426,9 +430,8 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: self.cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed ) - if need_cudagraphs_reset and self.cudagraph: - self.cudagraph.reset() - self.cudagraph = None + if need_cudagraphs_reset: + self._reset_captured_graph() if need_cudagraphs_record: self._input_buffers = [None] * len(self.input_names) diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index e6b6a21421..c3fe925eee 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -209,6 +209,9 @@ def set_device_memory_budget(self, budget_bytes: int) -> int: return budget_bytes + def _reset_captured_graph(self) -> None: + self.engine.reset_captured_graph() + def setup_engine(self) -> None: """ Setup engine for a module which has deferred engine setup. diff --git a/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py b/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py index f481c5b2b8..1b6963fa50 100644 --- a/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py +++ b/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py @@ -142,6 +142,9 @@ def automatic_device_memory_budget_getter(self) -> Any: def infer_outputs(self, input_shapes: List[Any]) -> Any: pass + def reset_captured_graph(self) -> Any: + pass + def __setstate__(self, serialized_state: List[str]) -> Any: pass diff --git a/py/torch_tensorrt/runtime/_cudagraphs.py b/py/torch_tensorrt/runtime/_cudagraphs.py index c771564826..346132145e 100644 --- a/py/torch_tensorrt/runtime/_cudagraphs.py +++ b/py/torch_tensorrt/runtime/_cudagraphs.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Union +from typing import Any, Optional, Union import torch import torch_tensorrt @@ -68,6 +68,7 @@ def __init__(self, compiled_module: torch.nn.Module) -> None: global _PY_RT_CUDAGRAPHS self.old_mode = _PY_RT_CUDAGRAPHS self.compiled_module = compiled_module + self.cudagraphs_module: Optional[CudaGraphsTorchTensorRTModule] = None def __enter__(self) -> torch.nn.Module: global _PY_RT_CUDAGRAPHS @@ -98,7 +99,8 @@ def __enter__(self) -> torch.nn.Module: logger.debug( "Found pytorch subgraphs in module, wrapping module in CudaGraphsTorchTensorRTModule" ) - return CudaGraphsTorchTensorRTModule(self.compiled_module) + self.cudagraphs_module = CudaGraphsTorchTensorRTModule(self.compiled_module) + return self.cudagraphs_module else: if num_trt_module > 0: logger.debug("No graph breaks detected, using runtime cudagraphs mode") @@ -113,6 +115,9 @@ def __enter__(self) -> torch.nn.Module: def __exit__(self, *args: Any) -> None: # Set cudagraphs back to old mode set_cudagraphs_mode(self.old_mode) + # __del__ is not entirely predictable, so we reset cudagraph here + if self.cudagraphs_module: + self.cudagraphs_module._reset_captured_graph() def enable_cudagraphs( diff --git a/py/torch_tensorrt/runtime/_weight_streaming.py b/py/torch_tensorrt/runtime/_weight_streaming.py index 3b11087fcb..0874d31d11 100755 --- a/py/torch_tensorrt/runtime/_weight_streaming.py +++ b/py/torch_tensorrt/runtime/_weight_streaming.py @@ -76,12 +76,15 @@ def _set_streamable_weight_bytes(self, requested_budget: int) -> int: int(streamable_bytes / total_bytes * requested_budget) for streamable_bytes in self.streamable_budget ] + if self.cuda_graphs_module: + self.cuda_graphs_module.is_weight_streaming_set = True + self.cuda_graphs_module._reset_captured_graph() + for i, (name, rt_mod) in enumerate(self.rt_mods): + rt_mod._reset_captured_graph() ws_budget_bytes += rt_mod.set_device_memory_budget(normalized_size[i]) logger.debug(f"Set weight streaming size {normalized_size[i]} for {name}") - if self.cuda_graphs_module: - self.cuda_graphs_module.is_weight_streaming_set = True return ws_budget_bytes def __setattr__(self, name: str, value: Any) -> None: