Skip to content

Cherrypick #3461 for release/2.7 #3464

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,10 @@ std::vector<std::string> TRTEngine::serialize() {
return serialized_info;
}

void TRTEngine::reset_captured_graph() {
cudagraph.reset();
}

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
1 change: 1 addition & 0 deletions core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ struct TRTEngine : torch::CustomClassHolder {
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);

void set_profiling_paths();
void reset_captured_graph();
#ifndef NDEBUG
bool profile_execution = true;
#else
Expand Down
1 change: 1 addition & 0 deletions core/runtime/register_jit_hooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
.def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info)
.def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
.def("infer_outputs", &TRTEngine::infer_outputs)
.def("reset_captured_graph", &TRTEngine::reset_captured_graph)
.def_readwrite("use_pre_allocated_outputs", &TRTEngine::use_pre_allocated_outputs)
.def_readwrite("use_output_allocator_outputs", &TRTEngine::use_output_allocator_outputs)
.def_property(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,13 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:

return False

def __del__(self) -> None:
def _reset_captured_graph(self) -> None:
if self.cudagraph:
self.cudagraph.reset()
self.cudagraph = None

def __del__(self) -> None:
self._reset_captured_graph()

def set_use_output_allocator(self, enable: bool) -> None:
self.use_output_allocator_outputs = enable
Expand All @@ -119,8 +123,7 @@ def forward(
shape_changed = self.validate_input_shapes(inputs)
need_cudagraphs_record = shape_changed or self.is_weight_streaming_set
if need_cudagraphs_record:
if self.cudagraph:
self.cudagraph.reset()
self._reset_captured_graph()
self._input_buffers = [None] * len(inputs)

self.is_weight_streaming_set = False
Expand Down Expand Up @@ -196,7 +199,5 @@ def forward(
return outputs[0]
return outputs
else:
if self.cudagraph:
self.cudagraph.reset()
self.cudagraph = None
self._reset_captured_graph()
return self.compiled_module(*args, **kwargs)
11 changes: 7 additions & 4 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,9 +333,13 @@ def __deepcopy__(self, memo: Any) -> PythonTorchTensorRTModule:
result.__setstate__(self.__getstate__())
return result

def __del__(self) -> None:
def _reset_captured_graph(self) -> None:
if self.cudagraph:
self.cudagraph.reset()
self.cudagraph = None

def __del__(self) -> None:
self._reset_captured_graph()

def setup_input_tensors(
self,
Expand Down Expand Up @@ -426,9 +430,8 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
self.cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
)

if need_cudagraphs_reset and self.cudagraph:
self.cudagraph.reset()
self.cudagraph = None
if need_cudagraphs_reset:
self._reset_captured_graph()

if need_cudagraphs_record:
self._input_buffers = [None] * len(self.input_names)
Expand Down
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,9 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:

return budget_bytes

def _reset_captured_graph(self) -> None:
self.engine.reset_captured_graph()

def setup_engine(self) -> None:
"""
Setup engine for a module which has deferred engine setup.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ def automatic_device_memory_budget_getter(self) -> Any:
def infer_outputs(self, input_shapes: List[Any]) -> Any:
pass

def reset_captured_graph(self) -> Any:
pass

def __setstate__(self, serialized_state: List[str]) -> Any:
pass

Expand Down
9 changes: 7 additions & 2 deletions py/torch_tensorrt/runtime/_cudagraphs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any, Union
from typing import Any, Optional, Union

import torch
import torch_tensorrt
Expand Down Expand Up @@ -68,6 +68,7 @@ def __init__(self, compiled_module: torch.nn.Module) -> None:
global _PY_RT_CUDAGRAPHS
self.old_mode = _PY_RT_CUDAGRAPHS
self.compiled_module = compiled_module
self.cudagraphs_module: Optional[CudaGraphsTorchTensorRTModule] = None

def __enter__(self) -> torch.nn.Module:
global _PY_RT_CUDAGRAPHS
Expand Down Expand Up @@ -98,7 +99,8 @@ def __enter__(self) -> torch.nn.Module:
logger.debug(
"Found pytorch subgraphs in module, wrapping module in CudaGraphsTorchTensorRTModule"
)
return CudaGraphsTorchTensorRTModule(self.compiled_module)
self.cudagraphs_module = CudaGraphsTorchTensorRTModule(self.compiled_module)
return self.cudagraphs_module
else:
if num_trt_module > 0:
logger.debug("No graph breaks detected, using runtime cudagraphs mode")
Expand All @@ -113,6 +115,9 @@ def __enter__(self) -> torch.nn.Module:
def __exit__(self, *args: Any) -> None:
# Set cudagraphs back to old mode
set_cudagraphs_mode(self.old_mode)
# __del__ is not entirely predictable, so we reset cudagraph here
if self.cudagraphs_module:
self.cudagraphs_module._reset_captured_graph()


def enable_cudagraphs(
Expand Down
7 changes: 5 additions & 2 deletions py/torch_tensorrt/runtime/_weight_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,15 @@ def _set_streamable_weight_bytes(self, requested_budget: int) -> int:
int(streamable_bytes / total_bytes * requested_budget)
for streamable_bytes in self.streamable_budget
]
if self.cuda_graphs_module:
self.cuda_graphs_module.is_weight_streaming_set = True
self.cuda_graphs_module._reset_captured_graph()

for i, (name, rt_mod) in enumerate(self.rt_mods):
rt_mod._reset_captured_graph()
ws_budget_bytes += rt_mod.set_device_memory_budget(normalized_size[i])
logger.debug(f"Set weight streaming size {normalized_size[i]} for {name}")

if self.cuda_graphs_module:
self.cuda_graphs_module.is_weight_streaming_set = True
return ws_budget_bytes

def __setattr__(self, name: str, value: Any) -> None:
Expand Down
Loading