From 092be2a09be94329d01d0652d121b440cec4323a Mon Sep 17 00:00:00 2001 From: keehyun Date: Wed, 18 Dec 2024 13:18:01 +0900 Subject: [PATCH] Wrapper module around TRT + pytorch subgraphs (#3270) --- core/runtime/execute_engine.cpp | 20 +-- core/runtime/register_jit_hooks.cpp | 6 +- core/runtime/runtime.cpp | 6 +- core/runtime/runtime.h | 13 +- docsrc/user_guide/runtime.rst | 2 +- examples/dynamo/torch_export_cudagraphs.py | 56 +++++- py/torch_tensorrt/_compile.py | 7 +- .../runtime/_CudaGraphsTorchTensorRTModule.py | 159 ++++++++++++++++++ .../runtime/_PythonTorchTensorRTModule.py | 5 +- py/torch_tensorrt/runtime/__init__.py | 1 + py/torch_tensorrt/runtime/_cudagraphs.py | 82 +++++++-- .../runtime/_weight_streaming.py | 12 +- .../dynamo/runtime/test_002_cudagraphs_cpp.py | 32 +++- .../dynamo/runtime/test_002_cudagraphs_py.py | 33 +++- 14 files changed, 376 insertions(+), 58 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index e871cd3467..fa38d0adaf 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -94,6 +94,7 @@ bool _validate_shapes(std::vector inputs, c10::intrusive_ptr inputs, c10::intrusive_ptr compiled_engine, + bool cudagraphs_enabled, bool need_cudagraphs_record) { // this is a buffer to store shape tensor input addresses throughout the runtime scope std::list> inputShapeTensorValues; @@ -127,7 +128,7 @@ void setup_input_tensors( compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()), "Error while setting the tensor address for shape inputs"); - if (CUDAGRAPHS_MODE) { + if (cudagraphs_enabled) { // @peri044 I dont know if this makes sense since they are supposed to be GPU buffers compiled_engine->input_buffers[i] = input_cpu; } @@ -147,7 +148,7 @@ void setup_input_tensors( TORCHTRT_CHECK( compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape"); - if (CUDAGRAPHS_MODE) { + if (cudagraphs_enabled) { // If using CUDAGraphs copy formatted input to the corresponding persistent input buffer compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true); TORCHTRT_CHECK( @@ -201,17 +202,17 @@ std::vector execute_engine(std::vector inputs, c10::intr LOG_INFO("" << log_info); compiled_engine->cudagraph.enable_debug_mode(); } - + bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS); bool shape_changed = _validate_shapes(inputs, compiled_engine); // Whether cudagraphs needs to record the graph on this pass auto result = compiled_engine->runtime_states.set_runtime_states( - CUDAGRAPHS_MODE, compiled_engine->use_pre_allocated_outputs, shape_changed); + cudagraphs_enabled, compiled_engine->use_pre_allocated_outputs, shape_changed); bool need_cudagraphs_record = std::get<0>(result); bool can_use_pre_allocated_outputs = std::get<1>(result); - if (!CUDAGRAPHS_MODE || shape_changed) { + if (!cudagraphs_enabled || shape_changed) { compiled_engine->cudagraph.reset(); } @@ -273,8 +274,7 @@ std::vector execute_engine(std::vector inputs, c10::intr std::make_unique(compiled_engine->input_profile_path); } - setup_input_tensors(inputs, compiled_engine, need_cudagraphs_record); - + setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record); // Check if input shapes can be inferred. 
int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()}; std::vector names(io_size); @@ -306,7 +306,7 @@ std::vector execute_engine(std::vector inputs, c10::intr compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone()); } - if (CUDAGRAPHS_MODE) { + if (cudagraphs_enabled) { TORCHTRT_CHECK( compiled_engine->exec_ctx->setTensorAddress( name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()), @@ -346,7 +346,7 @@ std::vector execute_engine(std::vector inputs, c10::intr caller_exec_complete.record(compiled_engine->caller_stream); caller_exec_complete.block(compiled_engine->engine_stream); - if (!CUDAGRAPHS_MODE) { + if (!cudagraphs_enabled) { // Direct execution uses the caller buffers directly compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream); } else { @@ -377,7 +377,7 @@ std::vector execute_engine(std::vector inputs, c10::intr trt_exec_complete.record(compiled_engine->engine_stream); trt_exec_complete.block(compiled_engine->caller_stream); - if (CUDAGRAPHS_MODE) { + if (cudagraphs_enabled) { // If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream) for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) { outputs[o].copy_(compiled_engine->output_buffers[o], false); diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index e5edcf9729..3ded080b1d 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -112,8 +112,10 @@ TORCH_LIBRARY(tensorrt, m) { m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void { MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode; }); - m.def("get_cudagraphs_mode", []() -> bool { return CUDAGRAPHS_MODE; }); - m.def("set_cudagraphs_mode", [](bool cudagraphs_mode) -> void { CUDAGRAPHS_MODE = cudagraphs_mode; }); + m.def("get_cudagraphs_mode", []() -> int64_t { return CUDAGRAPHS_MODE; }); + m.def("set_cudagraphs_mode", [](int64_t cudagraphs_mode) -> void { + CUDAGRAPHS_MODE = CudaGraphsMode(cudagraphs_mode); + }); m.def("set_logging_level", [](int64_t level) -> void { util::logging::get_logger().set_reportable_log_level(util::logging::LogLevel(level)); }); diff --git a/core/runtime/runtime.cpp b/core/runtime/runtime.cpp index b933e081c7..82b2fb1517 100644 --- a/core/runtime/runtime.cpp +++ b/core/runtime/runtime.cpp @@ -8,7 +8,7 @@ namespace core { namespace runtime { bool MULTI_DEVICE_SAFE_MODE = false; -bool CUDAGRAPHS_MODE = false; +CudaGraphsMode CUDAGRAPHS_MODE = STANDARD; c10::optional get_most_compatible_device( const RTDevice& target_device, @@ -130,11 +130,11 @@ void set_multi_device_safe_mode(bool multi_device_safe_mode) { MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode; } -bool get_cudagraphs_mode() { +CudaGraphsMode get_cudagraphs_mode() { return CUDAGRAPHS_MODE; } -void set_cudagraphs_mode(bool cudagraphs_mode) { +void set_cudagraphs_mode(CudaGraphsMode cudagraphs_mode) { CUDAGRAPHS_MODE = cudagraphs_mode; } diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index 86ba331796..6f1436c745 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -18,7 +18,14 @@ namespace runtime { using EngineID = int64_t; const std::string ABI_VERSION = "6"; extern bool MULTI_DEVICE_SAFE_MODE; -extern bool CUDAGRAPHS_MODE; + +typedef enum { + STANDARD = 0, + SUBGRAPH_CUDAGRAPHS, + WHOLE_GRAPH_CUDAGRAPHS, +} CudaGraphsMode; + +extern CudaGraphsMode CUDAGRAPHS_MODE; typedef enum { ABI_TARGET_IDX = 0, @@ -51,9 +58,9 @@ bool get_multi_device_safe_mode(); void 
set_multi_device_safe_mode(bool multi_device_safe_mode); -bool get_cudagraphs_mode(); +CudaGraphsMode get_cudagraphs_mode(); -void set_cudagraphs_mode(bool cudagraphs_mode); +void set_cudagraphs_mode(CudaGraphsMode cudagraphs_mode); class DeviceList { using DeviceMap = std::unordered_map; diff --git a/docsrc/user_guide/runtime.rst b/docsrc/user_guide/runtime.rst index c897ea1f78..8672fdebe4 100644 --- a/docsrc/user_guide/runtime.rst +++ b/docsrc/user_guide/runtime.rst @@ -86,7 +86,7 @@ Cudagraphs can accelerate certain models by reducing kernel overheads, as docume torch_tensorrt.runtime.set_cudagraphs_mode(False) # Enables Cudagraphs Mode, then resets the mode to its prior setting - with torch_tensorrt.runtime.enable_cudagraphs(): + with torch_tensorrt.runtime.enable_cudagraphs(trt_module): ... In the current implementation, use of a new input shape (for instance in dynamic shape diff --git a/examples/dynamo/torch_export_cudagraphs.py b/examples/dynamo/torch_export_cudagraphs.py index db7041b94d..1671c7783d 100644 --- a/examples/dynamo/torch_export_cudagraphs.py +++ b/examples/dynamo/torch_export_cudagraphs.py @@ -11,9 +11,8 @@ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ import torch -import torchvision.models as models - import torch_tensorrt +import torchvision.models as models # %% # Compilation with `torch_tensorrt.compile` Using Default Settings @@ -47,8 +46,8 @@ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # We can enable the cudagraphs API with a context manager -with torch_tensorrt.runtime.enable_cudagraphs(): - out_trt = opt(inputs) +with torch_tensorrt.runtime.enable_cudagraphs(opt) as cudagraphs_module: + out_trt = cudagraphs_module(inputs) # Alternatively, we can set the cudagraphs mode for the session torch_tensorrt.runtime.set_cudagraphs_mode(True) @@ -64,6 +63,49 @@ inputs_2 = torch.randn((8, 3, 224, 224)).cuda() inputs_3 = torch.randn((4, 3, 224, 224)).cuda() -with torch_tensorrt.runtime.enable_cudagraphs(): - out_trt_2 = opt(inputs_2) - out_trt_3 = opt(inputs_3) +with torch_tensorrt.runtime.enable_cudagraphs(opt) as cudagraphs_module: + out_trt_2 = cudagraphs_module(inputs_2) + out_trt_3 = cudagraphs_module(inputs_3) + +# %% +# Cuda graphs with module that contains graph breaks +# ---------------------------------- +# +# When CUDA Graphs are applied to a TensorRT model that contains graph breaks, each break introduces additional +# overhead. This occurs because graph breaks prevent the entire model from being executed as a single, continuous +# optimized unit. As a result, some of the performance benefits typically provided by CUDA Graphs, such as reduced +# kernel launch overhead and improved execution efficiency, may be diminished. +# Using a wrapped runtime module with CUDA Graphs allows you to encapsulate sequences of operations into graphs +# that can be executed efficiently, even in the presence of graph breaks. +# If TensorRT module has graph breaks, CUDA Graph context manager returns a wrapped_module. This module captures entire +# execution graph, enabling efficient replay during subsequent inferences by reducing kernel launch overheads +# and improving performance. Note that initializing with the wrapper module involves a warm-up phase where the +# module is executed several times. This warm-up ensures that memory allocations and initializations are not +# recorded in CUDA Graphs, which helps maintain consistent execution paths and optimize performance. 
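As an editorial aside (not part of this patch), the warm-up / record / replay flow described in the comment above can be sketched with the standard PyTorch CUDA graph APIs. The helper name `_capture_and_replay`, the warm-up count, and the calling convention are illustrative; the `CudaGraphsTorchTensorRTModule` introduced later in this patch implements an equivalent flow with its own buffer and stream management.

import torch

def _capture_and_replay(module, example_inputs, warmup_iters=3):
    # Warm up on a side stream so that lazy allocations and one-time
    # initializations are not recorded into the CUDA graph.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(warmup_iters):
            module(*example_inputs)
    torch.cuda.current_stream().wait_stream(s)

    # Record: capture one forward pass (TRT engines and Torch fallback
    # submodules alike) over persistent input buffers owned by the wrapper.
    static_inputs = [t.clone() for t in example_inputs]
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_outputs = module(*static_inputs)

    # Replay: later calls copy fresh data into the persistent buffers and
    # replay the captured graph, avoiding per-kernel launch overhead.
    def replay(*inputs):
        for dst, src in zip(static_inputs, inputs):
            dst.copy_(src)
        graph.replay()
        return static_outputs

    return replay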
+ + +class SampleModel(torch.nn.Module): + def forward(self, x): + return torch.relu((x + 2) * 0.5) + + +model = SampleModel().eval().cuda() +input = torch.randn((1, 3, 224, 224)).to("cuda") + +# The 'torch_executed_ops' compiler option is used in this example to intentionally introduce graph breaks within the module. +# Note: The Dynamo backend is required for the CUDA Graph context manager to handle modules in an Ahead-Of-Time (AOT) manner. +opt_with_graph_break = torch_tensorrt.compile( + model, + ir="dynamo", + inputs=[input], + min_block_size=1, + pass_through_build_failures=True, + torch_executed_ops={"torch.ops.aten.mul.Tensor"}, +) + +# %% +# If module has graph breaks, whole submodules are recorded and replayed by cuda graphs +with torch_tensorrt.runtime.enable_cudagraphs( + opt_with_graph_break +) as cudagraphs_module: + cudagraphs_module(input) diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 9492b09402..302928a784 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -12,6 +12,9 @@ from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt._Input import Input from torch_tensorrt.dynamo import _defaults +from torch_tensorrt.dynamo.runtime._CudaGraphsTorchTensorRTModule import ( + CudaGraphsTorchTensorRTModule, +) from torch_tensorrt.fx import InputTensorSpec from torch_tensorrt.fx.lower import compile as fx_compile from torch_tensorrt.fx.utils import LowerPrecision @@ -586,7 +589,7 @@ def save( Save the model to disk in the specified output format. Arguments: - module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)): Compiled Torch-TensorRT module + module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule | CudaGraphsTorchTensorRTModule)): Compiled Torch-TensorRT module inputs (torch.Tensor): Torch input tensors arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs. kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function. @@ -594,6 +597,8 @@ def save( retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it. This flag is experimental for now. 
""" + if isinstance(module, CudaGraphsTorchTensorRTModule): + module = module.compiled_module module_type = _parse_module_type(module) accepted_formats = {"exported_program", "torchscript"} if arg_inputs is not None and not all( diff --git a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py new file mode 100644 index 0000000000..e7afeef398 --- /dev/null +++ b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import logging +from typing import List, Optional, Sequence, Tuple + +import torch +import torch_tensorrt +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily +from torch_tensorrt.dynamo import partitioning + +logger = logging.getLogger(__name__) + + +class CudaGraphsTorchTensorRTModule(torch.nn.Module): # type: ignore[misc] + """This Wrapper runtime module is to record/replay whole cuda graph in sub modules + + Args: + compiled_module: Complied fx graphModule that will be wrapped + Returns: + Output tensor or tensor list + """ + + def __init__( + self, + compiled_module: torch.nn.Module, + ): + super(CudaGraphsTorchTensorRTModule, self).__init__() + self.compiled_module = compiled_module + self.inputs = partitioning.construct_submodule_inputs(compiled_module) + + self._input_buffers: List[torch.Tensor] = [] + self._output_buffers: List[torch.Tensor] = [] + self.cudagraph: Optional[torch.cuda.CUDAGraph] = None + self.shape_key: Optional[str] = None + self.prev_cudagraphs_enabled = False + self._caller_stream: Optional[torch.cuda.Stream] = None + self._engine_stream: Optional[torch.cuda.Stream] = None + self.warm_up() + + def warm_up(self) -> None: + """ + Warm up is necessary to ensure that memory allocations and initializations + are not recorded in cuda graphs + """ + with torch_tensorrt.logging.errors(): + with unset_fake_temporarily(): + inputs_tensor = [spec.torch_tensor.cuda() for spec in self.inputs] + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + self.compiled_module(*inputs_tensor) + torch.cuda.current_stream().wait_stream(s) + + def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool: + """ + Validates the input shapes of the forward function has changed + And infer output shapes if dynamic input shape has changed. 
+ """ + # Representation of input shapes to a given model + # Shapes are concatenated as so: + # x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5) + new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in inputs) + + if new_shape_key != self.shape_key: + logger.debug(f"Input shape changed {self.shape_key} -> {new_shape_key}") + self.shape_key = new_shape_key + return True + + return False + + def __del__(self) -> None: + if self.cudagraph: + self.cudagraph.reset() + + def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]: + cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode() + if cudagraphs_enabled: + shape_changed = self.validate_input_shapes(inputs) + # Cudagraphs record is required if cudagraphs_enabled is toggled to True regardless of shape change + need_cudagraphs_record = not self.prev_cudagraphs_enabled or shape_changed + self.prev_cudagraphs_enabled = cudagraphs_enabled + + if need_cudagraphs_record: + if self.cudagraph: + self.cudagraph.reset() + self._input_buffers = [None] * len(self.inputs) + + # Ensure inputs are available in all scopes and cast symbolic integers to Tensors + contiguous_inputs: List[torch.Tensor] = [ + ( + i.contiguous() + if isinstance(i, torch.Tensor) + else torch.tensor(i).cuda() + ) + for i in inputs + ] + assert len(contiguous_inputs) == len( + self.inputs + ), f"Wrong number of inputs, expect {len(self.inputs)} get {len(contiguous_inputs)}." + + for i, _ in enumerate(self.inputs): + if not contiguous_inputs[i].is_cuda: + logger.warning( + f"Detected input[{i}] is not on a cuda device. " + "This tensor is being moved by the runtime but for performance considerations, " + "ensure your inputs are all on GPU and open an issue here " + "(https://github.com/pytorch/TensorRT/issues) if this warning persists." + ) + contiguous_inputs = ( + contiguous_inputs[:i] + + [contiguous_inputs[i].cuda()] + + contiguous_inputs[i + 1 :] + ) + + assert ( + contiguous_inputs[i].dtype == self.inputs[i].dtype + ), f"Dtype mismatch for {i}th input. Expect {self.inputs[i].dtype}, got {contiguous_inputs[i].dtype}." 
+ + if need_cudagraphs_record: + # If cudagraphs is enabled, this memory is reserved for future cudagraph runs + # Clone is required to avoid re-using user-provided GPU memory + self._input_buffers[i] = contiguous_inputs[i].clone() + else: + self._input_buffers[i].copy_(contiguous_inputs[i]) + + self._caller_stream = torch.cuda.current_stream() + if ( + self._engine_stream == torch.cuda.default_stream() + or self._engine_stream is None + ): + self._engine_stream = torch.cuda.Stream() + + self._engine_stream.wait_stream(self._caller_stream) + + with torch.cuda.stream(self._engine_stream): + if need_cudagraphs_record: + self.cudagraph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self.cudagraph, stream=self._engine_stream): + self._output_buffers = self.compiled_module( + *self._input_buffers + ) + + self.cudagraph.replay() # type: ignore + self._caller_stream.wait_stream(self._engine_stream) + + if isinstance(self._output_buffers, (list, tuple)): + output_buffers = self._output_buffers + else: + output_buffers = [self._output_buffers] + outputs = [output.clone() for output in output_buffers] + if len(outputs) == 1: + return outputs[0] + return outputs + else: + if self.cudagraph: + self.cudagraph.reset() + self.cudagraph = None + return self.compiled_module(*inputs) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index e70d90086e..8996a2c486 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -30,7 +30,7 @@ def __init__(self, new_cudagraphs: bool, new_pre_allocated_output: bool): # Indicates whether pre-allocated output was enabled in the previous execute_engine self.old_pre_allocated_outputs = new_pre_allocated_output - def validate_states( + def set_runtime_states( self, new_cudagraphs: bool, new_pre_allocated_output: bool, @@ -347,7 +347,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda()) for i in inputs ] - with ( torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward") if self.profiling_enabled @@ -358,7 +357,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . 
cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() shape_changed = self.validate_input_shapes(inputs) need_cudagraphs_record, can_use_pre_allocated_outputs = ( - self.runtime_states.validate_states( + self.runtime_states.set_runtime_states( cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed ) ) diff --git a/py/torch_tensorrt/runtime/__init__.py b/py/torch_tensorrt/runtime/__init__.py index 9960460b60..470074a377 100644 --- a/py/torch_tensorrt/runtime/__init__.py +++ b/py/torch_tensorrt/runtime/__init__.py @@ -5,6 +5,7 @@ from torch_tensorrt.runtime._cudagraphs import ( enable_cudagraphs, get_cudagraphs_mode, + get_whole_cudagraphs_mode, set_cudagraphs_mode, ) from torch_tensorrt.runtime._multi_device_safe_mode import set_multi_device_safe_mode diff --git a/py/torch_tensorrt/runtime/_cudagraphs.py b/py/torch_tensorrt/runtime/_cudagraphs.py index 9d1523ef2e..d1564cb4dc 100644 --- a/py/torch_tensorrt/runtime/_cudagraphs.py +++ b/py/torch_tensorrt/runtime/_cudagraphs.py @@ -1,13 +1,26 @@ import logging -from typing import Any +from typing import Any, Union import torch import torch_tensorrt +from torch_tensorrt.dynamo.runtime._CudaGraphsTorchTensorRTModule import ( + CudaGraphsTorchTensorRTModule, +) + + +class CudaGraphsMode: + # No cuda graphs + STANDARD = 0 + # Cuda graphs is applied to TRT module + SUBGRAPH_CUDAGRAPHS = 1 + # Internal mode to apply cuda graphs for wrapped runtime module + WHOLE_GRAPH_CUDAGRAPHS = 2 + if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime: _PY_RT_CUDAGRAPHS = torch.ops.tensorrt.get_cudagraphs_mode() else: - _PY_RT_CUDAGRAPHS = False + _PY_RT_CUDAGRAPHS = CudaGraphsMode.STANDARD logger = logging.getLogger(__name__) @@ -16,19 +29,33 @@ def set_cudagraphs_mode(mode: bool) -> None: # Set new cudagraphs mode for Python global _PY_RT_CUDAGRAPHS - _PY_RT_CUDAGRAPHS = mode + _PY_RT_CUDAGRAPHS = ( + CudaGraphsMode.SUBGRAPH_CUDAGRAPHS if mode else CudaGraphsMode.STANDARD + ) # Set new mode for C++ if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime: - torch.ops.tensorrt.set_cudagraphs_mode(mode) + torch.ops.tensorrt.set_cudagraphs_mode(_PY_RT_CUDAGRAPHS) logger.info(f"Set Cudagraphs usage to {mode}") +def get_whole_cudagraphs_mode() -> bool: + # check if whole cudagraphs mode is enabled or not + global _PY_RT_CUDAGRAPHS + if _PY_RT_CUDAGRAPHS == CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS: + return True + else: + return False + + def get_cudagraphs_mode() -> bool: # Get cudagraphs mode for Python global _PY_RT_CUDAGRAPHS - return _PY_RT_CUDAGRAPHS # type: ignore + if _PY_RT_CUDAGRAPHS == CudaGraphsMode.SUBGRAPH_CUDAGRAPHS: + return True + else: + return False class _CudagraphsContextManager(object): @@ -37,19 +64,50 @@ class _CudagraphsContextManager(object): Used to enable cudagraphs as a context manager """ - def __init__(self) -> None: + def __init__(self, compiled_module: torch.nn.Module) -> None: global _PY_RT_CUDAGRAPHS self.old_mode = _PY_RT_CUDAGRAPHS + self.compiled_module = compiled_module + + def __enter__(self) -> torch.nn.Module: + global _PY_RT_CUDAGRAPHS - def __enter__(self) -> "_CudagraphsContextManager": - # Enable cudagraphs - set_cudagraphs_mode(True) - return self + num_torch_module = 0 + num_trt_module = 0 + for name, _ in self.compiled_module.named_children(): + if "_run_on_acc" in name: + num_trt_module += 1 + elif "_run_on_gpu" in name: + num_torch_module += 1 + + if num_torch_module > 0: + # Set whole cudagraphs mode and returns wrapped module + _PY_RT_CUDAGRAPHS = CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS + # 
Set new mode for C++ + if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime: + torch.ops.tensorrt.set_cudagraphs_mode(_PY_RT_CUDAGRAPHS) + + logger.debug( + "Found pytorch subgraphs in module, wrapping module in CudaGraphsTorchTensorRTModule" + ) + return CudaGraphsTorchTensorRTModule(self.compiled_module) + else: + if num_trt_module > 0: + logger.debug("No graph breaks detected, using runtime cudagraphs mode") + else: + logger.debug( + "Please consider dynamo if there is graph breaks. Using runtime cudagraphs mode" + ) + # Enable cudagraphs for TRT submodule + set_cudagraphs_mode(True) + return self.compiled_module def __exit__(self, *args: Any) -> None: # Set cudagraphs back to old mode set_cudagraphs_mode(self.old_mode) -def enable_cudagraphs() -> _CudagraphsContextManager: - return _CudagraphsContextManager() +def enable_cudagraphs( + compiled_module: Union[torch.fx.GraphModule, torch.nn.Module], +) -> _CudagraphsContextManager: + return _CudagraphsContextManager(compiled_module) diff --git a/py/torch_tensorrt/runtime/_weight_streaming.py b/py/torch_tensorrt/runtime/_weight_streaming.py index 3a33330fa1..42f02a02a8 100755 --- a/py/torch_tensorrt/runtime/_weight_streaming.py +++ b/py/torch_tensorrt/runtime/_weight_streaming.py @@ -1,8 +1,11 @@ import logging -from typing import Any +from typing import Any, Union import torch from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule +from torch_tensorrt.dynamo.runtime._CudaGraphsTorchTensorRTModule import ( + CudaGraphsTorchTensorRTModule, +) logger = logging.getLogger(__name__) @@ -12,9 +15,14 @@ class _WeightStreamingContextManager(object): Helper class used to setup weight streaming budget """ - def __init__(self, module: torch.fx.GraphModule) -> None: + def __init__( + self, module: Union[torch.fx.GraphModule, CudaGraphsTorchTensorRTModule] + ) -> None: rt_mods = [] self.current_device_budget = 0 + + if isinstance(module, CudaGraphsTorchTensorRTModule): + module = module.compiled_module for name, rt_mod in module.named_children(): if "_run_on_acc" in name and isinstance( rt_mod, (PythonTorchTensorRTModule, TorchTensorRTModule) diff --git a/tests/py/dynamo/runtime/test_002_cudagraphs_cpp.py b/tests/py/dynamo/runtime/test_002_cudagraphs_cpp.py index 8649ca8e84..92cd995505 100644 --- a/tests/py/dynamo/runtime/test_002_cudagraphs_cpp.py +++ b/tests/py/dynamo/runtime/test_002_cudagraphs_cpp.py @@ -30,7 +30,19 @@ def test_cudagraphs_off(self): self.assertFalse(torch.ops.tensorrt.get_cudagraphs_mode()) def test_cudagraphs_context(self): - with torch_tensorrt.runtime.enable_cudagraphs(): + class SampleModel(torch.nn.Module): + def forward(self, input): + return torch.ops.aten.abs.default(input) + + fx_graph = torch.fx.symbolic_trace(SampleModel()) + inputs = [torch.randn((2, 3), dtype=torch.float).cuda()] + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + ) + with torch_tensorrt.runtime.enable_cudagraphs(optimized_model) as _: self.assertTrue(torch.ops.tensorrt.get_cudagraphs_mode()) self.assertFalse(torch.ops.tensorrt.get_cudagraphs_mode()) @@ -54,9 +66,11 @@ def forward(self, x): result_samples = [] torch_results_samples = [] - with torch_tensorrt.runtime.enable_cudagraphs(): + with torch_tensorrt.runtime.enable_cudagraphs( + optimized_model + ) as cudagraphs_module: for i in inputs: - result_samples.append(optimized_model(i).detach().cpu()) + result_samples.append(cudagraphs_module(i).detach().cpu()) 
torch_results_samples.append(fx_graph(i).detach().cpu()) for i, (optimized_model_results, torch_model_results) in enumerate( @@ -95,9 +109,11 @@ def forward(self, x): result_samples = [] torch_results_samples = [] - with torch_tensorrt.runtime.enable_cudagraphs(): + with torch_tensorrt.runtime.enable_cudagraphs( + optimized_model + ) as cudagraphs_module: for i in inputs: - result_samples.append(optimized_model(i).detach().cpu()) + result_samples.append(cudagraphs_module(i).detach().cpu()) torch_results_samples.append(fx_graph(i).detach().cpu()) for i, (optimized_model_results, torch_model_results) in enumerate( @@ -144,9 +160,11 @@ def forward(self, x): result_samples = [] torch_results_samples = [] - with torch_tensorrt.runtime.enable_cudagraphs(): + with torch_tensorrt.runtime.enable_cudagraphs( + optimized_model + ) as cudagraphs_module: for i in inputs: - result_samples.append(optimized_model(i).detach().cpu()) + result_samples.append(cudagraphs_module(i).detach().cpu()) torch_results_samples.append(fx_graph(i).detach().cpu()) for i, (optimized_model_results, torch_model_results) in enumerate( diff --git a/tests/py/dynamo/runtime/test_002_cudagraphs_py.py b/tests/py/dynamo/runtime/test_002_cudagraphs_py.py index 4bdcfbbef4..0a4629644d 100644 --- a/tests/py/dynamo/runtime/test_002_cudagraphs_py.py +++ b/tests/py/dynamo/runtime/test_002_cudagraphs_py.py @@ -26,7 +26,20 @@ def test_cudagraphs_off(self): self.assertFalse(torch_tensorrt.runtime.get_cudagraphs_mode()) def test_cudagraphs_context(self): - with torch_tensorrt.runtime.enable_cudagraphs(): + class SampleModel(torch.nn.Module): + def forward(self, input): + return torch.ops.aten.abs.default(input) + + model = SampleModel().eval().cuda() + inputs = [torch.randn((2, 3), dtype=torch.float).cuda()] + optimized_model = torch_tensorrt.compile( + model, + "torch_compile", + inputs, + min_block_size=1, + use_python_runtime=True, + ) + with torch_tensorrt.runtime.enable_cudagraphs(optimized_model) as _: self.assertTrue(torch_tensorrt.runtime.get_cudagraphs_mode()) self.assertFalse(torch_tensorrt.runtime.get_cudagraphs_mode()) @@ -53,9 +66,11 @@ def forward(self, x): result_samples = [] torch_results_samples = [] - with torch_tensorrt.runtime.enable_cudagraphs(): + with torch_tensorrt.runtime.enable_cudagraphs( + optimized_model + ) as cudagraphs_module: for i in inputs: - result_samples.append(optimized_model(i).detach().cpu()) + result_samples.append(cudagraphs_module(i).detach().cpu()) torch_results_samples.append(fx_graph(i).detach().cpu()) for i, (optimized_model_results, torch_model_results) in enumerate( @@ -96,9 +111,11 @@ def forward(self, x): result_samples = [] torch_results_samples = [] - with torch_tensorrt.runtime.enable_cudagraphs(): + with torch_tensorrt.runtime.enable_cudagraphs( + optimized_model + ) as cudagraphs_module: for i in inputs: - result_samples.append(optimized_model(i).detach().cpu()) + result_samples.append(cudagraphs_module(i).detach().cpu()) torch_results_samples.append(fx_graph(i).detach().cpu()) for i, (optimized_model_results, torch_model_results) in enumerate( @@ -144,9 +161,11 @@ def forward(self, x): result_samples = [] torch_results_samples = [] - with torch_tensorrt.runtime.enable_cudagraphs(): + with torch_tensorrt.runtime.enable_cudagraphs( + optimized_model + ) as cudagraphs_module: for i in inputs: - result_samples.append(optimized_model(i).detach().cpu()) + result_samples.append(cudagraphs_module(i).detach().cpu()) torch_results_samples.append(fx_graph(i).detach().cpu()) for i, 
(optimized_model_results, torch_model_results) in enumerate(
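For reference, a minimal end-to-end usage sketch assembled from the example and tests in this patch (the model, input shape, and the op excluded via `torch_executed_ops` are illustrative):

import torch
import torch_tensorrt

class SampleModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu((x + 2) * 0.5)

model = SampleModel().eval().cuda()
x = torch.randn((1, 3, 224, 224)).cuda()

# Excluding an op forces a graph break, so the compiled module contains both
# TRT ("_run_on_acc") and Torch ("_run_on_gpu") submodules.
trt_module = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=[x],
    min_block_size=1,
    torch_executed_ops={"torch.ops.aten.mul.Tensor"},
)

# With graph breaks present, enable_cudagraphs() yields a
# CudaGraphsTorchTensorRTModule that records and replays the whole forward
# pass; without graph breaks it yields the module itself with the per-engine
# (SUBGRAPH_CUDAGRAPHS) mode enabled.
with torch_tensorrt.runtime.enable_cudagraphs(trt_module) as cudagraphs_module:
    out = cudagraphs_module(x)

# The per-engine mode can also be toggled globally for a session.
torch_tensorrt.runtime.set_cudagraphs_mode(True)
out = trt_module(x)
torch_tensorrt.runtime.set_cudagraphs_mode(False)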