feat: Add save API for torch-trt compiled models #2691

Merged · 1 commit · Mar 26, 2024
2 changes: 1 addition & 1 deletion core/runtime/TRTEngine.cpp
@@ -241,7 +241,7 @@ std::string TRTEngine::to_str() const {
exec_ctx->getEngine().getTensorDataType(out_binding_names[o].c_str()))
<< std::endl;
}
- ss << " }" << std::endl;
+ ss << " ]" << std::endl;
ss << " Device: " << device_info << std::endl;
ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
// clang-format on
71 changes: 35 additions & 36 deletions docsrc/user_guide/saving_models.rst
@@ -9,23 +9,22 @@ Saving models compiled with Torch-TensorRT
:undoc-members:
:show-inheritance:

- Saving models compiled with Torch-TensorRT varies slightly with the `ir` that has been used for compilation.
+ Saving models compiled with Torch-TensorRT can be done using the `torch_tensorrt.save` API.

Dynamo IR
-------------

- The output type of `ir=dynamo` compilation of Torch-TensorRT is `torch.export.ExportedProgram` object by default.
- In addition, we provide a new parameter `output_format` in the `CompilationSetting` object provided before compilation.
- The `output_format` can take the following options
+ The output type of `ir=dynamo` compilation of Torch-TensorRT is a `torch.fx.GraphModule` object by default.
+ We can save this object in either `TorchScript` (`torch.jit.ScriptModule`) or `ExportedProgram` (`torch.export.ExportedProgram`) format by
+ specifying the `output_format` flag. Here are the options `output_format` accepts:

- * `exported_program` (or) `ep` : This is the default. Returns an ExportedProgram
- * `torchscript` (or) `ts` : This returns a TorchScript module
- * `graph_module` (or) `fx` : This returns a torch.fx.GraphModule which can be traced into Torchscript to save to disk.
+ * `exported_program` : This is the default. We perform transformations on the graph module first and use `torch.export.save` to save the module.
+ * `torchscript` : We trace the graph module via `torch.jit.trace` and save it via `torch.jit.save`.

- a) Torchscript
+ a) ExportedProgram
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

- If you set the `output_format="torchscript"`, this will return a `ScriptModule` which can be serialized via torch.jit.save
+ Here's an example usage:

.. code-block:: python

@@ -34,50 +33,32 @@ If you set the `output_format="torchscript"`, this will return a `ScriptModule`

model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
- # trt_ts is a torch.jit.ScriptModule object
- trt_ts = torch_tensorrt.compile(model, ir="dynamo", inputs, output_format="torchscript")
- torch.jit.save(trt_ts, "trt_model.ts")
+ # trt_gm is a torch.fx.GraphModule object
+ trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
+ torch_tensorrt.save(trt_gm, "trt.ep", inputs=inputs)

# Later, you can load it and run inference
- model = torch.jit.load("trt_model.ts").cuda()
+ model = torch.export.load("trt.ep").module()
model(*inputs)

- b) ExportedProgram
+ b) Torchscript
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

- `torch.export.ExportedProgram`, a new format introduced in Pytorch 2.X is the default return type of Torch-TensorRT compilation.

.. code-block:: python

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
- # trt_ep is a torch.export.ExportedProgram object
- trt_ep = torch_tensorrt.compile(model, ir="dynamo", inputs)
- torch.export.save(trt_ep, "trt_model.ep")
+ # trt_gm is a torch.fx.GraphModule object
+ trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
+ torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", inputs=inputs)

# Later, you can load it and run inference
- model = torch.export.load("trt_model.ep")
+ model = torch.jit.load("trt.ts").cuda()
model(*inputs)

- c) GraphModule
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
- We can also return a `torch.fx.GraphModule` object as the output of Torch-TensorRT compilation by setting `output_format="graph_module"`.
- Internally, partitioning, lowering, conversion phases operate using GraphModule objects. These can be either traced into a Torchscript modules or
- exported into `ExportedProgram` objects
-
- .. code-block:: python
-
-     import torch
-     import torch_tensorrt
-
-     model = MyModel().eval().cuda()
-     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-     # trt_gm is a torch.fx.GraphModule object
-     trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs, output_format="graph_module")
-
Torchscript IR
--------------
@@ -99,3 +80,21 @@ For `ir=ts`, this behavior stays the same in 2.X versions as well.
model = torch.jit.load("trt_model.ts").cuda()
model(*inputs)


+ Loading the models
+ --------------------
+
+ We can load torchscript or exported_program models using the `torch.jit.load` and `torch.export.load` APIs from PyTorch directly.
+ Alternatively, we provide a light wrapper `torch_tensorrt.load(file_path)` which can load either of the above model types.
+
+ Here's an example usage:
+
+ .. code-block:: python
+
+     import torch
+     import torch_tensorrt
+
+     # file_path can be a trt.ep or trt.ts file obtained via saving the model (refer to the above section)
+     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
+     model = torch_tensorrt.load(<file_path>).module()
+     model(*inputs)
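Note: `torch_tensorrt.load` returns whichever object the file contains, either a `torch.jit.ScriptModule` or a `torch.export.ExportedProgram` (see the `load` implementation in `py/torch_tensorrt/_compile.py` below), and only an ExportedProgram needs the `.module()` call before inference. A minimal sketch of type-aware loading, assuming a file saved as in the sections above:

    import torch
    import torch_tensorrt

    obj = torch_tensorrt.load("trt.ep")  # also works for a "trt.ts" file
    if isinstance(obj, torch.export.ExportedProgram):
        model = obj.module()  # unwrap the callable GraphModule
    else:
        model = obj  # a ScriptModule is directly callable

    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
    model(*inputs)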
111 changes: 107 additions & 4 deletions py/torch_tensorrt/_compile.py
@@ -6,6 +6,7 @@

import torch
import torch.fx
+ import torch_tensorrt.dynamo
import torch_tensorrt.ts
from torch_tensorrt._enums import dtype
from torch_tensorrt._Input import Input
@@ -26,10 +27,7 @@

logger = logging.getLogger(__name__)

- __all__ = [
-     "compile",
-     "convert_method_to_trt_engine",
- ]
+ __all__ = ["compile", "convert_method_to_trt_engine", "save", "load"]


def _non_fx_input_interface(
@@ -332,3 +330,108 @@ def convert_method_to_trt_engine(
)
else:
raise RuntimeError("Module is an unknown format or the ir requested is unknown")


+ def load(file_path: str = "") -> Any:
+     """
+     Load either a Torchscript model or an ExportedProgram. The type is autodetected
+     by attempting torch.jit.load() first and falling back to torch.export.load().
+     """
+     try:
+         logger.debug("Loading the provided file using torch.jit.load()")
+         ts_module = torch.jit.load(file_path)
+         return ts_module
+     except Exception:
+         logger.debug(
+             "Loading the provided file via torch.jit.load() failed with the following error",
+             exc_info=True,
+         )
+         pass
+
+     try:
+         logger.debug("Loading the provided file using torch.export.load()")
+         exp_program = torch.export.load(file_path)
+         return exp_program
+     except Exception:
+         logger.debug(
+             "Loading the provided file via torch.export.load() failed with the following error",
+             exc_info=True,
+         )
+         raise ValueError(
+             "The file doesn't correspond to a valid Torchscript module or ExportedProgram. Please verify the file path."
+         )
+
+
+ def save(
+     module: Any,
+     file_path: str = "",
+     *,
+     output_format: str = "exported_program",
+     inputs: Optional[Sequence[torch.Tensor]] = None,
+     retrace: bool = False,
+ ) -> None:
+     """
+     Save the model to disk in the specified output format.
+     Arguments:
+         module : Compiled Torch-TensorRT module (Options include torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)
+         inputs (torch.Tensor): Torch input tensors
+         output_format: Format to save the model. Options include exported_program | torchscript.
+         retrace: When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
+             This flag is experimental for now.
+     """
+     module_type = _parse_module_type(module)
+     accepted_formats = {"exported_program", "torchscript"}
+     if inputs is not None and not all(
+         isinstance(input, torch.Tensor) for input in inputs
+     ):
+         raise ValueError(
+             "Not all inputs provided are torch.tensors. Please provide torch.tensors as inputs"
+         )
+     if output_format not in accepted_formats:
+         raise ValueError(
+             f"Provided output_format {output_format} is not supported. Supported options are exported_program | torchscript"
+         )
+     if not file_path:
+         raise ValueError("File path cannot be empty. Please provide a valid file path")
+
+     if module_type == _ModuleType.nn:
+         raise ValueError(
+             "Input model is of type nn.Module. Saving nn.Module directly is not supported. Supported model types torch.jit.ScriptModule | torch.fx.GraphModule | torch.export.ExportedProgram."
+         )
+     elif module_type == _ModuleType.ts:
+         if output_format == "exported_program":
+             raise ValueError(
+                 "Provided model is a torch.jit.ScriptModule but the output_format specified is exported_program. Please verify the output_format"
+             )
+         else:
+             torch.jit.save(module, file_path)
+     elif module_type == _ModuleType.ep:
+         if output_format == "torchscript":
+             raise ValueError(
+                 "Provided model is a torch.export.ExportedProgram but the output_format specified is torchscript. Please verify the output_format"
+             )
+         else:
+             torch.export.save(module, file_path)
+     elif module_type == _ModuleType.fx:
+         if inputs is None:
+             raise ValueError(
+                 "Provided model is a torch.fx.GraphModule however the inputs are empty. Please provide valid torch.tensors as inputs to trace and save the model"
+             )
+         # The module type is torch.fx.GraphModule
+         if output_format == "torchscript":
+             module_ts = torch.jit.trace(module, inputs)
+             torch.jit.save(module_ts, file_path)
+         else:
+             if not retrace:
+                 from torch_tensorrt.dynamo._exporter import export
+
+                 exp_program = export(module, inputs)
+                 torch.export.save(exp_program, file_path)
+             else:
+                 from torch._higher_order_ops.torchbind import enable_torchbind_tracing
+
+                 with enable_torchbind_tracing():
+                     exp_program = torch.export.export(
+                         module, tuple(inputs), strict=False
+                     )
+                     torch.export.save(exp_program, file_path)
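For reference, a usage sketch of the `save` dispatch above. `MyModel` and the file names are illustrative, and `trt_gm` is assumed to come from dynamo compilation as in the updated docs:

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()  # assumed user model, as in the docs above
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)  # torch.fx.GraphModule

    # GraphModule + exported_program (default): runs the dynamo exporter, then torch.export.save
    torch_tensorrt.save(trt_gm, "trt.ep", inputs=inputs)

    # GraphModule + torchscript: traced with torch.jit.trace, then torch.jit.save
    torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", inputs=inputs)

    # Experimental: re-export with torch.export.export(strict=False) instead of the dynamo exporter
    torch_tensorrt.save(trt_gm, "trt_retrace.ep", inputs=inputs, retrace=True)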
9 changes: 2 additions & 7 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -30,7 +30,6 @@
MIN_BLOCK_SIZE,
NUM_AVG_TIMING_ITERS,
OPTIMIZATION_LEVEL,
- OUTPUT_FORMAT,
PASS_THROUGH_BUILD_FAILURES,
PRECISION,
REFIT,
@@ -48,7 +47,6 @@
dryrun_stats_display,
parse_non_trt_nodes,
)
- from torch_tensorrt.dynamo._exporter import export
from torch_tensorrt.dynamo.conversion import (
CompilationSettings,
UnsupportedOperatorException,
@@ -102,9 +100,8 @@ def compile(
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
dryrun: bool = DRYRUN,
hardware_compatible: bool = HARDWARE_COMPATIBLE,
- output_format: str = OUTPUT_FORMAT,
**kwargs: Any,
- ) -> Union[ExportedProgram, torch.jit.ScriptModule, torch.fx.GraphModule]:
+ ) -> torch.fx.GraphModule:
"""Compile a TorchScript module for NVIDIA GPUs using TensorRT

Takes an existing TorchScript module and a set of settings to configure the compiler
@@ -246,14 +243,12 @@ def compile(
"dla_global_dram_size": dla_global_dram_size,
"dryrun": dryrun,
"hardware_compatible": hardware_compatible,
- "output_format": output_format,
}

settings = CompilationSettings(**compilation_options)
logger.info("Compilation Settings: %s\n", settings)
trt_gm = compile_module(gm, inputs, settings)
- trt_result = export(trt_gm, torch_inputs, output_format)
- return trt_result
+ return trt_gm


def compile_module(
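Net effect of these changes: compilation and serialization are decoupled. `compile` with `ir="dynamo"` now always returns a `torch.fx.GraphModule`, and conversion to TorchScript or ExportedProgram happens only in `save`. A sketch of the resulting two-step workflow, reusing `model` and `inputs` from the examples above:

    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)  # always a GraphModule now
    torch_tensorrt.save(trt_gm, "trt.ep", inputs=inputs)                # format chosen at save time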
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/_defaults.py
@@ -26,7 +26,6 @@
REQUIRE_FULL_COMPILATION = False
DRYRUN = False
HARDWARE_COMPATIBLE = False
- OUTPUT_FORMAT = "exported_program"


def default_device() -> Device: