Wrapper module around TRT + pytorch subgraphs (#3270)
keehyuna authored Dec 18, 2024
1 parent bd11733 commit 092be2a
Showing 14 changed files with 376 additions and 58 deletions.
20 changes: 10 additions & 10 deletions core/runtime/execute_engine.cpp
@@ -94,6 +94,7 @@ bool _validate_shapes(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngi
void setup_input_tensors(
std::vector<at::Tensor> inputs,
c10::intrusive_ptr<TRTEngine> compiled_engine,
bool cudagraphs_enabled,
bool need_cudagraphs_record) {
// this is a buffer to store shape tensor input addresses throughout the runtime scope
std::list<std::vector<int64_t>> inputShapeTensorValues;
@@ -127,7 +128,7 @@ void setup_input_tensors(
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()),
"Error while setting the tensor address for shape inputs");

if (CUDAGRAPHS_MODE) {
if (cudagraphs_enabled) {
// @peri044 I dont know if this makes sense since they are supposed to be GPU buffers
compiled_engine->input_buffers[i] = input_cpu;
}
@@ -147,7 +148,7 @@ void setup_input_tensors(
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape");

if (CUDAGRAPHS_MODE) {
if (cudagraphs_enabled) {
// If using CUDAGraphs copy formatted input to the corresponding persistent input buffer
compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true);
TORCHTRT_CHECK(
@@ -201,17 +202,17 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
LOG_INFO("" << log_info);
compiled_engine->cudagraph.enable_debug_mode();
}

bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS);
bool shape_changed = _validate_shapes(inputs, compiled_engine);

// Whether cudagraphs needs to record the graph on this pass
auto result = compiled_engine->runtime_states.set_runtime_states(
CUDAGRAPHS_MODE, compiled_engine->use_pre_allocated_outputs, shape_changed);
cudagraphs_enabled, compiled_engine->use_pre_allocated_outputs, shape_changed);

bool need_cudagraphs_record = std::get<0>(result);
bool can_use_pre_allocated_outputs = std::get<1>(result);

if (!CUDAGRAPHS_MODE || shape_changed) {
if (!cudagraphs_enabled || shape_changed) {
compiled_engine->cudagraph.reset();
}

@@ -273,8 +274,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
}

setup_input_tensors(inputs, compiled_engine, need_cudagraphs_record);

setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record);
// Check if input shapes can be inferred.
int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
std::vector<char const*> names(io_size);
@@ -306,7 +306,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone());
}

if (CUDAGRAPHS_MODE) {
if (cudagraphs_enabled) {
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(
name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()),
@@ -346,7 +346,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
caller_exec_complete.record(compiled_engine->caller_stream);
caller_exec_complete.block(compiled_engine->engine_stream);

if (!CUDAGRAPHS_MODE) {
if (!cudagraphs_enabled) {
// Direct execution uses the caller buffers directly
compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream);
} else {
@@ -377,7 +377,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
trt_exec_complete.record(compiled_engine->engine_stream);
trt_exec_complete.block(compiled_engine->caller_stream);

if (CUDAGRAPHS_MODE) {
if (cudagraphs_enabled) {
// If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream)
for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) {
outputs[o].copy_(compiled_engine->output_buffers[o], false);
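The control flow above reduces to a small decision procedure. Below is a minimal Python sketch of that gating, assuming the tuple returned by set_runtime_states unpacks to (need_cudagraphs_record, can_use_pre_allocated_outputs) as shown in the hunk; set_runtime_states itself is not part of this diff, so the recording condition is illustrative only.

# Illustrative sketch of the cudagraphs gating in execute_engine.
# Mode values follow the CudaGraphsMode enum added in core/runtime/runtime.h below.
STANDARD, SUBGRAPH_CUDAGRAPHS, WHOLE_GRAPH_CUDAGRAPHS = 0, 1, 2


def plan_execution(cudagraphs_mode: int, shape_changed: bool, graph_recorded: bool):
    # Per-engine cudagraphs run only in SUBGRAPH_CUDAGRAPHS mode;
    # WHOLE_GRAPH_CUDAGRAPHS is handled by the wrapper module instead.
    cudagraphs_enabled = cudagraphs_mode == SUBGRAPH_CUDAGRAPHS
    # Mirrors `if (!cudagraphs_enabled || shape_changed) cudagraph.reset();`
    reset_graph = not cudagraphs_enabled or shape_changed
    # Hypothetical recording condition: record when cudagraphs are on but no
    # valid capture exists (the real logic lives in set_runtime_states).
    need_cudagraphs_record = cudagraphs_enabled and (reset_graph or not graph_recorded)
    return cudagraphs_enabled, reset_graph, need_cudagraphs_record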
6 changes: 4 additions & 2 deletions core/runtime/register_jit_hooks.cpp
@@ -112,8 +112,10 @@ TORCH_LIBRARY(tensorrt, m) {
m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void {
MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
});
m.def("get_cudagraphs_mode", []() -> bool { return CUDAGRAPHS_MODE; });
m.def("set_cudagraphs_mode", [](bool cudagraphs_mode) -> void { CUDAGRAPHS_MODE = cudagraphs_mode; });
m.def("get_cudagraphs_mode", []() -> int64_t { return CUDAGRAPHS_MODE; });
m.def("set_cudagraphs_mode", [](int64_t cudagraphs_mode) -> void {
CUDAGRAPHS_MODE = CudaGraphsMode(cudagraphs_mode);
});
m.def("set_logging_level", [](int64_t level) -> void {
util::logging::get_logger().set_reportable_log_level(util::logging::LogLevel(level));
});
6 changes: 3 additions & 3 deletions core/runtime/runtime.cpp
@@ -8,7 +8,7 @@ namespace core {
namespace runtime {

bool MULTI_DEVICE_SAFE_MODE = false;
bool CUDAGRAPHS_MODE = false;
CudaGraphsMode CUDAGRAPHS_MODE = STANDARD;

c10::optional<RTDevice> get_most_compatible_device(
const RTDevice& target_device,
@@ -130,11 +130,11 @@ void set_multi_device_safe_mode(bool multi_device_safe_mode) {
MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
}

bool get_cudagraphs_mode() {
CudaGraphsMode get_cudagraphs_mode() {
return CUDAGRAPHS_MODE;
}

void set_cudagraphs_mode(bool cudagraphs_mode) {
void set_cudagraphs_mode(CudaGraphsMode cudagraphs_mode) {
CUDAGRAPHS_MODE = cudagraphs_mode;
}

13 changes: 10 additions & 3 deletions core/runtime/runtime.h
@@ -18,7 +18,14 @@ namespace runtime {
using EngineID = int64_t;
const std::string ABI_VERSION = "6";
extern bool MULTI_DEVICE_SAFE_MODE;
extern bool CUDAGRAPHS_MODE;

typedef enum {
STANDARD = 0,
SUBGRAPH_CUDAGRAPHS,
WHOLE_GRAPH_CUDAGRAPHS,
} CudaGraphsMode;

extern CudaGraphsMode CUDAGRAPHS_MODE;

typedef enum {
ABI_TARGET_IDX = 0,
@@ -51,9 +58,9 @@ bool get_multi_device_safe_mode();

void set_multi_device_safe_mode(bool multi_device_safe_mode);

bool get_cudagraphs_mode();
CudaGraphsMode get_cudagraphs_mode();

void set_cudagraphs_mode(bool cudagraphs_mode);
void set_cudagraphs_mode(CudaGraphsMode cudagraphs_mode);

class DeviceList {
using DeviceMap = std::unordered_map<int, RTDevice>;
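Because the JIT hooks above now take and return an int64_t, these modes can also be driven directly from Python through the registered ops. A hedged sketch follows (the op names come from the TORCH_LIBRARY(tensorrt, m) block in register_jit_hooks.cpp above; the integer values follow the enum ordering in this hunk):

import torch
import torch_tensorrt  # noqa: F401  (loads the tensorrt ops into torch.ops)

# CudaGraphsMode values, per the enum in core/runtime/runtime.h
STANDARD, SUBGRAPH_CUDAGRAPHS, WHOLE_GRAPH_CUDAGRAPHS = 0, 1, 2

# Enable per-engine cudagraphs, then read the mode back
torch.ops.tensorrt.set_cudagraphs_mode(SUBGRAPH_CUDAGRAPHS)
assert torch.ops.tensorrt.get_cudagraphs_mode() == SUBGRAPH_CUDAGRAPHS

torch.ops.tensorrt.set_cudagraphs_mode(STANDARD)  # restore the default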
2 changes: 1 addition & 1 deletion docsrc/user_guide/runtime.rst
@@ -86,7 +86,7 @@ Cudagraphs can accelerate certain models by reducing kernel overheads, as docume
torch_tensorrt.runtime.set_cudagraphs_mode(False)
# Enables Cudagraphs Mode, then resets the mode to its prior setting
with torch_tensorrt.runtime.enable_cudagraphs():
with torch_tensorrt.runtime.enable_cudagraphs(trt_module):
...
In the current implementation, use of a new input shape (for instance in dynamic shape
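Put together, the documented change would be used like the following sketch (trt_module is assumed to be a module compiled with ir="dynamo"; the wrapper returned by the context manager replaces direct calls to the module):

import torch
import torch_tensorrt

# trt_module: a Dynamo-compiled Torch-TensorRT module (assumed in scope)
inputs = torch.randn((16, 3, 224, 224)).cuda()  # example inputs (assumed shape)
with torch_tensorrt.runtime.enable_cudagraphs(trt_module) as cudagraphs_module:
    out = cudagraphs_module(inputs)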
56 changes: 49 additions & 7 deletions examples/dynamo/torch_export_cudagraphs.py
@@ -11,9 +11,8 @@
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import torch
import torchvision.models as models

import torch_tensorrt
import torchvision.models as models

# %%
# Compilation with `torch_tensorrt.compile` Using Default Settings
@@ -47,8 +46,8 @@
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# We can enable the cudagraphs API with a context manager
with torch_tensorrt.runtime.enable_cudagraphs():
out_trt = opt(inputs)
with torch_tensorrt.runtime.enable_cudagraphs(opt) as cudagraphs_module:
out_trt = cudagraphs_module(inputs)

# Alternatively, we can set the cudagraphs mode for the session
torch_tensorrt.runtime.set_cudagraphs_mode(True)
@@ -64,6 +63,49 @@
inputs_2 = torch.randn((8, 3, 224, 224)).cuda()
inputs_3 = torch.randn((4, 3, 224, 224)).cuda()

with torch_tensorrt.runtime.enable_cudagraphs():
out_trt_2 = opt(inputs_2)
out_trt_3 = opt(inputs_3)
with torch_tensorrt.runtime.enable_cudagraphs(opt) as cudagraphs_module:
out_trt_2 = cudagraphs_module(inputs_2)
out_trt_3 = cudagraphs_module(inputs_3)

# %%
# CUDA Graphs with modules that contain graph breaks
# --------------------------------------------------
#
# When CUDA Graphs are applied to a TensorRT model that contains graph breaks, each break introduces additional
# overhead. This occurs because graph breaks prevent the entire model from being executed as a single, continuous
# optimized unit. As a result, some of the performance benefits typically provided by CUDA Graphs, such as reduced
# kernel launch overhead and improved execution efficiency, may be diminished.
# Using a wrapped runtime module with CUDA Graphs allows you to encapsulate sequences of operations into graphs
# that can be executed efficiently, even in the presence of graph breaks.
# If the TensorRT module has graph breaks, the CUDA Graph context manager returns a wrapped module. This module
# captures the entire execution graph, enabling efficient replay during subsequent inferences by reducing kernel
# launch overheads and improving performance. Note that initializing with the wrapper module involves a warm-up
# phase in which the module is executed several times. This warm-up ensures that memory allocations and
# initializations are not recorded in CUDA Graphs, which helps maintain consistent execution paths and optimize
# performance.


class SampleModel(torch.nn.Module):
def forward(self, x):
return torch.relu((x + 2) * 0.5)


model = SampleModel().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")

# The 'torch_executed_ops' compiler option is used in this example to intentionally introduce graph breaks within the module.
# Note: The Dynamo backend is required for the CUDA Graph context manager to handle modules in an Ahead-Of-Time (AOT) manner.
opt_with_graph_break = torch_tensorrt.compile(
model,
ir="dynamo",
inputs=[input],
min_block_size=1,
pass_through_build_failures=True,
torch_executed_ops={"torch.ops.aten.mul.Tensor"},
)

# %%
# If the module has graph breaks, whole submodules are recorded and replayed by CUDA Graphs
with torch_tensorrt.runtime.enable_cudagraphs(
opt_with_graph_break
) as cudagraphs_module:
cudagraphs_module(input)
7 changes: 6 additions & 1 deletion py/torch_tensorrt/_compile.py
@@ -12,6 +12,9 @@
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import _defaults
from torch_tensorrt.dynamo.runtime._CudaGraphsTorchTensorRTModule import (
CudaGraphsTorchTensorRTModule,
)
from torch_tensorrt.fx import InputTensorSpec
from torch_tensorrt.fx.lower import compile as fx_compile
from torch_tensorrt.fx.utils import LowerPrecision
@@ -586,14 +589,16 @@ def save(
Save the model to disk in the specified output format.
Arguments:
module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)): Compiled Torch-TensorRT module
module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule | CudaGraphsTorchTensorRTModule)): Compiled Torch-TensorRT module
inputs (torch.Tensor): Torch input tensors
arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
output_format (str): Format to save the model. Options include exported_program | torchscript.
retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
This flag is experimental for now.
"""
if isinstance(module, CudaGraphsTorchTensorRTModule):
module = module.compiled_module
module_type = _parse_module_type(module)
accepted_formats = {"exported_program", "torchscript"}
if arg_inputs is not None and not all(
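A short usage sketch of this unwrap path (hedged; reuses opt_with_graph_break and input from the torch_export_cudagraphs example earlier on this page):

import torch_tensorrt

# opt_with_graph_break and input come from the graph-break example above
with torch_tensorrt.runtime.enable_cudagraphs(opt_with_graph_break) as cudagraphs_module:
    cudagraphs_module(input)
    # save() now accepts the CudaGraphsTorchTensorRTModule wrapper and
    # transparently saves its underlying compiled_module
    torch_tensorrt.save(cudagraphs_module, "trt_model.ep", inputs=[input])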
