Commit bb7592f

feat: Add save API for torch-trt compiled models (#2691)
1 parent f39e89e commit bb7592f

File tree

11 files changed
+311 -261 lines changed

core/runtime/TRTEngine.cpp
+1 -1

@@ -241,7 +241,7 @@ std::string TRTEngine::to_str() const {
             exec_ctx->getEngine().getTensorDataType(out_binding_names[o].c_str()))
          << std::endl;
   }
-  ss << "  }" << std::endl;
+  ss << "  ]" << std::endl;
   ss << "  Device: " << device_info << std::endl;
   ss << "  Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
   // clang-format on

docsrc/user_guide/saving_models.rst
+35 -36

@@ -9,23 +9,22 @@ Saving models compiled with Torch-TensorRT
    :undoc-members:
    :show-inheritance:

-Saving models compiled with Torch-TensorRT varies slightly with the `ir` that has been used for compilation.
+Saving models compiled with Torch-TensorRT can be done using the `torch_tensorrt.save` API.

 Dynamo IR
 -------------

-The output type of `ir=dynamo` compilation of Torch-TensorRT is `torch.export.ExportedProgram` object by default.
-In addition, we provide a new parameter `output_format` in the `CompilationSetting` object provided before compilation.
-The `output_format` can take the following options
+The output type of `ir=dynamo` compilation of Torch-TensorRT is a `torch.fx.GraphModule` object by default.
+We can save this object in either `TorchScript` (`torch.jit.ScriptModule`) or `ExportedProgram` (`torch.export.ExportedProgram`) format by
+specifying the `output_format` flag. Here are the options `output_format` will accept:

-* `exported_program` (or) `ep` : This is the default. Returns an ExportedProgram
-* `torchscript` (or) `ts` : This returns a TorchScript module
-* `graph_module` (or) `fx` : This returns a torch.fx.GraphModule which can be traced into Torchscript to save to disk.
+* `exported_program` : This is the default. We perform transformations on the GraphModule first and use `torch.export.save` to save the module.
+* `torchscript` : We trace the GraphModule via `torch.jit.trace` and save it via `torch.jit.save`.

-a) Torchscript
+a) ExportedProgram
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-If you set the `output_format="torchscript"`, this will return a `ScriptModule` which can be serialized via torch.jit.save
+Here's an example usage:

 .. code-block:: python

@@ -34,50 +33,32 @@ If you set the `output_format="torchscript"`, this will return a `ScriptModule`
     model = MyModel().eval().cuda()
     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-    # trt_ts is a torch.jit.ScriptModule object
-    trt_ts = torch_tensorrt.compile(model, ir="dynamo", inputs, output_format="torchscript")
-    torch.jit.save(trt_ts, "trt_model.ts")
+    # trt_gm is a torch.fx.GraphModule object
+    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
+    torch_tensorrt.save(trt_gm, "trt.ep", inputs=inputs)

     # Later, you can load it and run inference
-    model = torch.jit.load("trt_model.ts").cuda()
+    model = torch.export.load("trt.ep").module()
     model(*inputs)

-b) ExportedProgram
+b) Torchscript
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-`torch.export.ExportedProgram`, a new format introduced in Pytorch 2.X is the default return type of Torch-TensorRT compilation.
-
 .. code-block:: python

     import torch
     import torch_tensorrt

     model = MyModel().eval().cuda()
     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-    # trt_ep is a torch.export.ExportedProgram object
-    trt_ep = torch_tensorrt.compile(model, ir="dynamo", inputs)
-    torch.export.save(trt_ep, "trt_model.ep")
+    # trt_gm is a torch.fx.GraphModule object
+    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
+    torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", inputs=inputs)

     # Later, you can load it and run inference
-    model = torch.export.load("trt_model.ep")
+    model = torch.jit.load("trt.ts").cuda()
     model(*inputs)

-c) GraphModule
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-We can also return a `torch.fx.GraphModule` object as the output of Torch-TensorRT compilation by setting `output_format="graph_module"`.
-Internally, partitioning, lowering, conversion phases operate using GraphModule objects. These can be either traced into Torchscript modules or
-exported into `ExportedProgram` objects
-
-.. code-block:: python
-
-    import torch
-    import torch_tensorrt
-
-    model = MyModel().eval().cuda()
-    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-    # trt_gm is a torch.fx.GraphModule object
-    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs, output_format="graph_module")

 Torchscript IR
 -------------
@@ -99,3 +80,21 @@ For `ir=ts`, this behavior stays the same in 2.X versions as well.
     model = torch.jit.load("trt_model.ts").cuda()
     model(*inputs)

+
+Loading the models
+--------------------
+
+We can load torchscript or exported_program models using the `torch.jit.load` and `torch.export.load` APIs from PyTorch directly.
+Alternatively, we provide a light wrapper, `torch_tensorrt.load(file_path)`, which can load either of the above model types.
+
+Here's an example usage:
+
+.. code-block:: python
+
+    import torch
+    import torch_tensorrt
+
+    # file_path can be a trt.ep or trt.ts file obtained via saving the model (refer to the above section)
+    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
+    model = torch_tensorrt.load(<file_path>).module()
+    model(*inputs)
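The `torch_tensorrt.load` wrapper documented above autodetects the serialization format by simply trying one loader and falling back to the next. Below is a minimal, framework-free sketch of that pattern; the `load_any` name and the JSON/pickle stand-in formats are illustrative, not part of Torch-TensorRT.

```python
import json
import pickle
from typing import Any


def load_any(file_path: str) -> Any:
    """Try each known deserializer in turn; raise only if all of them fail."""
    try:
        # First candidate format: JSON (fails fast on non-text or non-JSON input)
        with open(file_path, "r") as f:
            return json.load(f)
    except Exception:
        pass  # fall through to the next loader

    try:
        # Second candidate format: pickle
        with open(file_path, "rb") as f:
            return pickle.load(f)
    except Exception:
        raise ValueError(
            f"{file_path} is not a valid JSON or pickle file. Please verify the file path."
        )
```

This mirrors the try/except chain in `torch_tensorrt.load`, which attempts `torch.jit.load` before falling back to `torch.export.load`.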

py/torch_tensorrt/_compile.py
+107 -4

@@ -6,6 +6,7 @@

 import torch
 import torch.fx
+import torch_tensorrt.dynamo
 import torch_tensorrt.ts
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._Input import Input
@@ -26,10 +27,7 @@

 logger = logging.getLogger(__name__)

-__all__ = [
-    "compile",
-    "convert_method_to_trt_engine",
-]
+__all__ = ["compile", "convert_method_to_trt_engine", "save", "load"]


 def _non_fx_input_interface(
@@ -332,3 +330,108 @@ def convert_method_to_trt_engine(
         )
     else:
         raise RuntimeError("Module is an unknown format or the ir requested is unknown")
+
+
+def load(file_path: str = "") -> Any:
+    """
+    Load either a Torchscript model or ExportedProgram. Autodetect the type using
+    try, except
+    """
+    try:
+        logger.debug("Loading the provided file using torch.jit.load()")
+        ts_module = torch.jit.load(file_path)
+        return ts_module
+    except Exception:
+        logger.debug(
+            "Loading the provided file via torch.jit.load() failed with the following error",
+            exc_info=True,
+        )
+        pass
+
+    try:
+        logger.debug("Loading the provided file using torch.export.load()")
+        exp_program = torch.export.load(file_path)
+        return exp_program
+    except Exception:
+        logger.debug(
+            "Loading the provided file via torch.export.load() failed with the following error",
+            exc_info=True,
+        )
+        raise ValueError(
+            "The file doesn't correspond to a valid Torchscript module or ExportedProgram. Please verify the file path."
+        )
+
+
+def save(
+    module: Any,
+    file_path: str = "",
+    *,
+    output_format: str = "exported_program",
+    inputs: Optional[Sequence[torch.Tensor]] = None,
+    retrace: bool = False,
+) -> None:
+    """
+    Save the model to disk in the specified output format.
+    Arguments:
+        module : Compiled Torch-TensorRT module (Options include torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)
+        inputs (torch.Tensor): Torch input tensors
+        output_format: Format to save the model. Options include exported_program | torchscript.
+        retrace: When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
+                 This flag is experimental for now.
+    """
+    module_type = _parse_module_type(module)
+    accepted_formats = {"exported_program", "torchscript"}
+    if inputs is not None and not all(
+        isinstance(input, torch.Tensor) for input in inputs
+    ):
+        raise ValueError(
+            "Not all inputs provided are torch.tensors. Please provide torch.tensors as inputs"
+        )
+    if output_format not in accepted_formats:
+        raise ValueError(
+            f"Provided output_format {output_format} is not supported. Supported options are exported_program | torchscript"
+        )
+    if not file_path:
+        raise ValueError("File path cannot be empty. Please provide a valid file path")
+
+    if module_type == _ModuleType.nn:
+        raise ValueError(
+            "Input model is of type nn.Module. Saving nn.Module directly is not supported. Supported model types torch.jit.ScriptModule | torch.fx.GraphModule | torch.export.ExportedProgram."
+        )
+    elif module_type == _ModuleType.ts:
+        if output_format == "exported_program":
+            raise ValueError(
+                "Provided model is a torch.jit.ScriptModule but the output_format specified is exported_program. Please verify the output_format"
+            )
+        else:
+            torch.jit.save(module, file_path)
+    elif module_type == _ModuleType.ep:
+        if output_format == "torchscript":
+            raise ValueError(
+                "Provided model is a torch.export.ExportedProgram but the output_format specified is torchscript. Please verify the output_format"
+            )
+        else:
+            torch.export.save(module, file_path)
+    elif module_type == _ModuleType.fx:
+        if inputs is None:
+            raise ValueError(
+                "Provided model is a torch.fx.GraphModule however the inputs are empty. Please provide valid torch.tensors as inputs to trace and save the model"
+            )
+        # The module type is torch.fx.GraphModule
+        if output_format == "torchscript":
+            module_ts = torch.jit.trace(module, inputs)
+            torch.jit.save(module_ts, file_path)
+        else:
+            if not retrace:
+                from torch_tensorrt.dynamo._exporter import export
+
+                exp_program = export(module, inputs)
+                torch.export.save(exp_program, file_path)
+            else:
+                from torch._higher_order_ops.torchbind import enable_torchbind_tracing
+
+                with enable_torchbind_tracing():
+                    exp_program = torch.export.export(
+                        module, tuple(inputs), strict=False
+                    )
+                    torch.export.save(exp_program, file_path)

py/torch_tensorrt/dynamo/_compiler.py
+2 -7

@@ -30,7 +30,6 @@
     MIN_BLOCK_SIZE,
     NUM_AVG_TIMING_ITERS,
     OPTIMIZATION_LEVEL,
-    OUTPUT_FORMAT,
     PASS_THROUGH_BUILD_FAILURES,
     PRECISION,
     REFIT,
@@ -48,7 +47,6 @@
     dryrun_stats_display,
     parse_non_trt_nodes,
 )
-from torch_tensorrt.dynamo._exporter import export
 from torch_tensorrt.dynamo.conversion import (
     CompilationSettings,
     UnsupportedOperatorException,
@@ -102,9 +100,8 @@ def compile(
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     dryrun: bool = DRYRUN,
     hardware_compatible: bool = HARDWARE_COMPATIBLE,
-    output_format: str = OUTPUT_FORMAT,
     **kwargs: Any,
-) -> Union[ExportedProgram, torch.jit.ScriptModule, torch.fx.GraphModule]:
+) -> torch.fx.GraphModule:
     """Compile a TorchScript module for NVIDIA GPUs using TensorRT

     Takes a existing TorchScript module and a set of settings to configure the compiler
@@ -246,14 +243,12 @@ def compile(
         "dla_global_dram_size": dla_global_dram_size,
         "dryrun": dryrun,
         "hardware_compatible": hardware_compatible,
-        "output_format": output_format,
     }

     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
     trt_gm = compile_module(gm, inputs, settings)
-    trt_result = export(trt_gm, torch_inputs, output_format)
-    return trt_result
+    return trt_gm


 def compile_module(
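The `_compiler.py` change above narrows `compile` from a `Union` return type (which depended on the `output_format` flag) to always returning the GraphModule, deferring format conversion to `save`. A hedged sketch of the before/after shape of such an API follows; all names (`compile_v1`, `compile_v2`, `save_v2`) and the string/bytes stand-in types are illustrative, not Torch-TensorRT code.

```python
from typing import Union


# Before (sketch): the flag leaks into the return type, so every caller
# must narrow a Union before doing anything useful with the result.
def compile_v1(model: str, output_format: str = "exported_program") -> Union[str, bytes]:
    compiled = f"compiled({model})"
    return compiled.encode() if output_format == "torchscript" else compiled


# After (sketch): compilation always yields one intermediate type;
# the serialization format is chosen later, at save time.
def compile_v2(model: str) -> str:
    return f"compiled({model})"


def save_v2(compiled: str, output_format: str = "exported_program") -> bytes:
    # Conversion lives in one place instead of being baked into compile()
    payload = compiled if output_format == "exported_program" else f"ts:{compiled}"
    return payload.encode()
```

The single return type makes the compiled artifact composable: callers can run inference, inspect, or save it without branching on what `compile` happened to return.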

py/torch_tensorrt/dynamo/_defaults.py
-1

@@ -26,7 +26,6 @@
 REQUIRE_FULL_COMPILATION = False
 DRYRUN = False
 HARDWARE_COMPATIBLE = False
-OUTPUT_FORMAT = "exported_program"


 def default_device() -> Device:
