Commit 1abe62f
Parent: f39e89e
Author:    Dheeraj Peri <[email protected]> 1711393059 -0700
Committer: Dheeraj Peri <[email protected]> 1711393072 -0700

chore: minor updates
chore: Fix save failures
chore: minor fixes
chore: remove duplicate bert test case
chore: remove comments
chore: add load api
chore: minor updates

14 files changed: +318 −270 lines

.github/workflows/build-test.yml (+1)

@@ -169,6 +169,7 @@ jobs:
           cd tests/py/dynamo
           ${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
           ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
+          ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/save_load_test_results.xml --ir dynamo models/test_save_load.py
           popd

   tests-py-torch-compile-be:

core/runtime/TRTEngine.cpp (+1 −1)

@@ -241,7 +241,7 @@ std::string TRTEngine::to_str() const {
               exec_ctx->getEngine().getTensorDataType(out_binding_names[o].c_str()))
            << std::endl;
     }
-    ss << "  }" << std::endl;
+    ss << "  ]" << std::endl;
     ss << "  Device: " << device_info << std::endl;
     ss << "  Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
     // clang-format on

docsrc/user_guide/saving_models.rst (+35 −36)

@@ -9,23 +9,22 @@ Saving models compiled with Torch-TensorRT
    :undoc-members:
    :show-inheritance:

-Saving models compiled with Torch-TensorRT varies slightly with the `ir` that has been used for compilation.
+Saving models compiled with Torch-TensorRT can be done using the `torch_tensorrt.save` API.

 Dynamo IR
 -------------

-The output type of `ir=dynamo` compilation of Torch-TensorRT is a `torch.export.ExportedProgram` object by default.
-In addition, we provide a new parameter `output_format` in the `CompilationSetting` object provided before compilation.
-The `output_format` can take the following options
+The output type of `ir=dynamo` compilation of Torch-TensorRT is a `torch.fx.GraphModule` object by default.
+We can save this object in either `TorchScript` (`torch.jit.ScriptModule`) or `ExportedProgram` (`torch.export.ExportedProgram`) format by
+specifying the `output_format` flag. `output_format` accepts the following options

-* `exported_program` (or) `ep` : This is the default. Returns an ExportedProgram
-* `torchscript` (or) `ts` : This returns a TorchScript module
-* `graph_module` (or) `fx` : This returns a torch.fx.GraphModule which can be traced into TorchScript to save to disk.
+* `exported_program` : This is the default. We perform transformations on the graph module first and use `torch.export.save` to save the module.
+* `torchscript` : We trace the graph module via `torch.jit.trace` and save it via `torch.jit.save`.

-a) Torchscript
+a) ExportedProgram
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-If you set `output_format="torchscript"`, this will return a `ScriptModule` which can be serialized via `torch.jit.save`.
+Here's an example usage

 .. code-block:: python

@@ -34,50 +33,32 @@
     import torch
     import torch_tensorrt

     model = MyModel().eval().cuda()
     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-    # trt_ts is a torch.jit.ScriptModule object
-    trt_ts = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, output_format="torchscript")
-    torch.jit.save(trt_ts, "trt_model.ts")
+    # trt_gm is a torch.fx.GraphModule object
+    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
+    torch_tensorrt.save(trt_gm, "trt.ep", inputs=inputs)

     # Later, you can load it and run inference
-    model = torch.jit.load("trt_model.ts").cuda()
+    model = torch.export.load("trt.ep").module()
     model(*inputs)

-b) ExportedProgram
+b) Torchscript
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-`torch.export.ExportedProgram`, a new format introduced in PyTorch 2.X, is the default return type of Torch-TensorRT compilation.
-
 .. code-block:: python

     import torch
     import torch_tensorrt

     model = MyModel().eval().cuda()
     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-    # trt_ep is a torch.export.ExportedProgram object
-    trt_ep = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
-    torch.export.save(trt_ep, "trt_model.ep")
+    # trt_gm is a torch.fx.GraphModule object
+    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
+    torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", inputs=inputs)

     # Later, you can load it and run inference
-    model = torch.export.load("trt_model.ep")
+    model = torch.jit.load("trt.ts").cuda()
     model(*inputs)

-c) GraphModule
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-We can also return a `torch.fx.GraphModule` object as the output of Torch-TensorRT compilation by setting `output_format="graph_module"`.
-Internally, the partitioning, lowering, and conversion phases operate using GraphModule objects. These can be either traced into TorchScript modules or
-exported into `ExportedProgram` objects.
-
-.. code-block:: python
-
-    import torch
-    import torch_tensorrt
-
-    model = MyModel().eval().cuda()
-    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-    # trt_gm is a torch.fx.GraphModule object
-    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, output_format="graph_module")

 Torchscript IR
 -------------

@@ -99,3 +80,21 @@ For `ir=ts`, this behavior stays the same in 2.X versions as well.
     model = torch.jit.load("trt_model.ts").cuda()
     model(*inputs)
+
+Loading the models
+--------------------
+
+We can load TorchScript or ExportedProgram models using the `torch.jit.load` and `torch.export.load` APIs from PyTorch directly.
+Alternatively, we provide a light wrapper `torch_tensorrt.load(file_path)` which can load either of the above model types.
+
+Here's an example usage
+
+.. code-block:: python
+
+    import torch
+    import torch_tensorrt
+
+    # file_path can be the trt.ep or trt.ts file obtained by saving the model (refer to the above section)
+    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
+    model = torch_tensorrt.load(<file_path>).module()
+    model(*inputs)
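The `torch_tensorrt.load` wrapper documented above autodetects the on-disk format by attempting one loader after another. Here is a minimal, torch-free sketch of that try-in-order pattern; the `load_ts`/`load_ep` functions below are hypothetical stand-ins for `torch.jit.load` and `torch.export.load`, which each raise on a file they cannot parse:

```python
def load_any(file_path, loaders):
    """Try each (name, loader) pair in order; return the first successful result."""
    errors = []
    for name, loader in loaders:
        try:
            return loader(file_path)
        except Exception as e:
            # A format mismatch just means "try the next loader"
            errors.append(f"{name}: {e}")
    raise ValueError(
        "The file doesn't correspond to any supported format. Tried: "
        + "; ".join(errors)
    )

# Hypothetical stand-in loaders keyed on file extension; the real wrapper
# distinguishes formats by actually attempting deserialization.
def load_ts(path):
    if path.endswith(".ts"):
        return "scriptmodule"
    raise RuntimeError("not a TorchScript archive")

def load_ep(path):
    if path.endswith(".ep"):
        return "exported_program"
    raise RuntimeError("not an ExportedProgram archive")

loaders = [("torchscript", load_ts), ("exported_program", load_ep)]
print(load_any("model.ep", loaders))  # prints "exported_program"
```

One consequence of this design, visible in the real `load` below as well, is that the TorchScript loader is always tried first, and only the final failure surfaces as a `ValueError`.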

py/torch_tensorrt/_compile.py (+92 −4)

@@ -6,6 +6,7 @@

 import torch
 import torch.fx
+import torch_tensorrt.dynamo
 import torch_tensorrt.ts
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._Input import Input

@@ -26,10 +27,7 @@

 logger = logging.getLogger(__name__)

-__all__ = [
-    "compile",
-    "convert_method_to_trt_engine",
-]
+__all__ = ["compile", "convert_method_to_trt_engine", "save", "load"]


 def _non_fx_input_interface(

@@ -332,3 +330,93 @@ def convert_method_to_trt_engine(
         )
     else:
         raise RuntimeError("Module is an unknown format or the ir requested is unknown")
+
+
+def load(file_path: str = "") -> Any:
+    """
+    Load either a Torchscript model or ExportedProgram. Autodetect the type using
+    try, except
+    """
+    try:
+        ts_module = torch.jit.load(file_path)
+        return ts_module
+    except Exception:
+        pass
+
+    try:
+        exp_program = torch.export.load(file_path)
+        return exp_program
+    except Exception:
+        raise ValueError(
+            "The file doesn't correspond to a Torchscript module or ExportedProgram. Please verify the file path."
+        )
+
+
+def save(
+    module: Any,
+    file_path: str = "",
+    *,
+    output_format: str = "exported_program",
+    inputs: Optional[Sequence[torch.Tensor]] = None,
+    retrace: bool = False,
+) -> None:
+    """
+    Save the model to disk in the specified output format.
+
+    Arguments:
+        module : Compiled Torch-TensorRT module (Options include torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)
+        inputs (torch.Tensor): Torch input tensors
+    """
+    module_type = _parse_module_type(module)
+    accepted_formats = {"exported_program", "torchscript"}
+    if inputs and not all(isinstance(input, torch.Tensor) for input in inputs):
+        raise ValueError(
+            "Not all inputs provided are torch.tensors. Please provide torch.tensors as inputs"
+        )
+    if output_format not in accepted_formats:
+        raise ValueError(
+            f"Provided output_format {output_format} is not supported. Supported options are exported_program | torchscript"
+        )
+    if not file_path:
+        raise ValueError("File path cannot be empty. Please provide a valid file path")
+
+    if module_type == _ModuleType.nn:
+        raise ValueError(
+            "Input model is of type nn.Module. Saving nn.Module directly is not supported. Supported model types torch.jit.ScriptModule | torch.fx.GraphModule | torch.export.ExportedProgram."
+        )
+    elif module_type == _ModuleType.ts:
+        if output_format == "exported_program":
+            raise ValueError(
+                "Provided model is a torch.jit.ScriptModule but the output_format specified is exported_program. Please verify the output_format"
+            )
+        else:
+            torch.jit.save(module, file_path)
+    elif module_type == _ModuleType.ep:
+        if output_format == "torchscript":
+            raise ValueError(
+                "Provided model is a torch.export.ExportedProgram but the output_format specified is torchscript. Please verify the output_format"
+            )
+        else:
+            torch.export.save(module, file_path)
+    elif module_type == _ModuleType.fx:
+        if not inputs:
+            raise ValueError(
+                "Provided model is a torch.fx.GraphModule however the inputs are empty. Please provide valid torch.tensors as inputs to trace and save the model"
+            )
+        # The module type is torch.fx.GraphModule
+        if output_format == "torchscript":
+            module_ts = torch.jit.trace(module, inputs)
+            torch.jit.save(module_ts, file_path)
+        else:
+            if not retrace:
+                from torch_tensorrt.dynamo._exporter import export
+
+                exp_program = export(module, inputs)
+                torch.export.save(exp_program, file_path)
+            else:
+                from torch._higher_order_ops.torchbind import enable_torchbind_tracing
+
+                with enable_torchbind_tracing():
+                    exp_program = torch.export.export(
+                        module, tuple(inputs), strict=False
+                    )
+                    torch.export.save(exp_program, file_path)
py/torch_tensorrt/dynamo/_compiler.py (+2 −7)

@@ -30,7 +30,6 @@
     MIN_BLOCK_SIZE,
     NUM_AVG_TIMING_ITERS,
     OPTIMIZATION_LEVEL,
-    OUTPUT_FORMAT,
     PASS_THROUGH_BUILD_FAILURES,
     PRECISION,
     REFIT,

@@ -48,7 +47,6 @@
     dryrun_stats_display,
     parse_non_trt_nodes,
 )
-from torch_tensorrt.dynamo._exporter import export
 from torch_tensorrt.dynamo.conversion import (
     CompilationSettings,
     UnsupportedOperatorException,

@@ -102,9 +100,8 @@ def compile(
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     dryrun: bool = DRYRUN,
     hardware_compatible: bool = HARDWARE_COMPATIBLE,
-    output_format: str = OUTPUT_FORMAT,
     **kwargs: Any,
-) -> Union[ExportedProgram, torch.jit.ScriptModule, torch.fx.GraphModule]:
+) -> torch.fx.GraphModule:
     """Compile a TorchScript module for NVIDIA GPUs using TensorRT

     Takes an existing TorchScript module and a set of settings to configure the compiler

@@ -246,14 +243,12 @@ def compile(
         "dla_global_dram_size": dla_global_dram_size,
         "dryrun": dryrun,
         "hardware_compatible": hardware_compatible,
-        "output_format": output_format,
     }

     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
     trt_gm = compile_module(gm, inputs, settings)
-    trt_result = export(trt_gm, torch_inputs, output_format)
-    return trt_result
+    return trt_gm


 def compile_module(
py/torch_tensorrt/dynamo/_defaults.py (−1)

@@ -26,7 +26,6 @@
 REQUIRE_FULL_COMPILATION = False
 DRYRUN = False
 HARDWARE_COMPATIBLE = False
-OUTPUT_FORMAT = "exported_program"


 def default_device() -> Device:

0 commit comments