
Commit c8ffcde

rebase
1 parent 9cf3356 commit c8ffcde

File tree (4 files changed: +53 -39 lines changed)

    docsrc/py_api/dynamo.rst
    py/torch_tensorrt/_compile.py
    py/torch_tensorrt/dynamo/_compiler.py
    tests/py/dynamo/runtime/test_convert_module_to_trt_engine.py

Diff for: docsrc/py_api/dynamo.rst (+2)

@@ -22,6 +22,8 @@ Functions
 
 .. autofunction:: export
 
+.. autofunction:: convert_module_to_trt_engine
+
 
 
 Classes

Diff for: py/torch_tensorrt/_compile.py (+12, -4)

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import collections.abc
 import logging
 from enum import Enum
 from typing import Any, Callable, List, Optional, Sequence, Set
@@ -240,8 +241,6 @@ def compile(
         return compiled_fx_module
     elif target_ir == _IRType.dynamo:
         # Prepare torch and torchtrt inputs
-        import collections.abc
-
         from torch_tensorrt.dynamo.utils import prepare_inputs
 
         if not isinstance(input_list, collections.abc.Sequence):
@@ -345,10 +344,19 @@ def convert_method_to_trt_engine(
             "convert_method_to_trt_engine call is not supported for ir=fx"
         )
    elif target_ir == _IRType.dynamo:
+        # Prepare torch and torchtrt inputs
+        from torch_tensorrt.dynamo.utils import prepare_inputs
+
+        if not isinstance(inputs, collections.abc.Sequence):
+            inputs = [inputs]
+
+        # Export the module
+        torchtrt_inputs = prepare_inputs(inputs)
+        exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
+
         return dynamo_convert_module_to_trt_engine(  # type: ignore[no-any-return]
-            module,
+            exp_program,
             inputs=inputs,
-            method_name=method_name,
             enabled_precisions=enabled_precisions_set,
             **kwargs,
         )
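
With this change, the dynamo branch of the frontend normalizes the inputs, traces the module to an ExportedProgram via torch_tensorrt.dynamo.trace, and hands that program (not the raw module) to the dynamo converter. A minimal sketch of calling the frontend after this patch, assuming the public entry point keeps its ir and inputs keywords as the dispatch in this file implies (model, shapes, and devices are illustrative):

import torch
import torch_tensorrt

class Add(torch.nn.Module):
    def forward(self, a, b):
        return a + b

# Illustrative inputs; a lone tensor would be wrapped into a list by the
# isinstance(inputs, collections.abc.Sequence) check added above.
inputs = [torch.randn(2, 3).cuda(), torch.randn(2, 3).cuda()]

# ir="dynamo" routes through the patched branch: prepare_inputs normalizes
# the inputs, dynamo.trace produces an ExportedProgram, and that program is
# forwarded to dynamo's convert_module_to_trt_engine.
engine_bytes = torch_tensorrt.convert_method_to_trt_engine(
    Add().eval().cuda(),
    ir="dynamo",
    inputs=inputs,
)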

Diff for: py/torch_tensorrt/dynamo/_compiler.py (+33, -31)

@@ -416,8 +416,7 @@ def compile_module(
 
 
 def convert_module_to_trt_engine(
-    module: torch.fx.GraphModule,
-    method_name: str = "forward",
+    exported_program: ExportedProgram,
     inputs: Optional[Sequence[Input | torch.Tensor]] = None,
     enabled_precisions: (
         Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
@@ -435,6 +434,7 @@ def convert_module_to_trt_engine(
     use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER,
     enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     device: Device = Device._current_device(),
+<<<<<<< HEAD
     require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION,
     disable_tf32: bool = _defaults.DISABLE_TF32,
     sparse_weights: bool = _defaults.SPARSE_WEIGHTS,
@@ -446,16 +446,27 @@ def convert_module_to_trt_engine(
     dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE,
     calibrator: object = None,
     allow_shape_tensors: bool = False,
+=======
+    require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
+    disable_tf32: bool = DISABLE_TF32,
+    sparse_weights: bool = SPARSE_WEIGHTS,
+    refit: bool = REFIT,
+    engine_capability: EngineCapability = ENGINE_CAPABILITY,
+    num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS,
+    dla_sram_size: int = DLA_SRAM_SIZE,
+    dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE,
+    dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE,
+>>>>>>> 51ee9efc56e3ad13d4c76099c2e332f778d0c976
 ) -> bytes:
-    """Convert a GraphModule module method to a serialized TensorRT engine
+    """Convert an ExportedProgram to a serialized TensorRT engine
 
-    Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings
+    Converts an ExportedProgram to a serialized TensorRT engine given a dictionary of conversion settings
 
     Arguments:
-        module (torch.fx.GraphModule): Source module
+        exported_program (torch.export.ExportedProgram): Source module
 
     Keyword Args:
-        inputs (List[Union(torch_tensorrt.Input, torch.Tensor)]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
+        inputs (Optional[Sequence[torch_tensorrt.Input | torch.Tensor]]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
             torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum
             to select device type. ::
 
@@ -470,30 +481,11 @@ def convert_module_to_trt_engine(
                 ), # Dynamic input shape for input #2
                 torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
             ]
-
-        method_name (str): Name of method to convert
-        input_signature Union(List, Tuple, torch_tensorrt.Input, torch.Tensor): A formatted collection of input specifications for the module. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
-            torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum to select device type. **This API should be considered beta-level stable and may change in the future** ::
-
-            input_signature=([
-                torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
-                torch_tensorrt.Input(
-                    min_shape=(1, 224, 224, 3),
-                    opt_shape=(1, 512, 512, 3),
-                    max_shape=(1, 1024, 1024, 3),
-                    dtype=torch.int32
-                    format=torch.channel_last
-                ), # Dynamic input shape for input #2
-            ], torch.randn((1, 3, 224, 244))) # Use an example tensor and let torch_tensorrt infer settings for input #3
-
-        device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on ::
-
-            device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
-
+        enabled_precisions (Optional[Set[torch.dtype | _enums.dtype]]): The set of datatypes that TensorRT can use
         debug (bool): Whether to print out verbose debugging information
         workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
         min_block_size (int): Minimum number of operators per TRT-Engine Block
-        torch_executed_ops (Sequence[str]): Sequence of operations to run in Torch, regardless of converter coverage
+        torch_executed_ops (Set[str]): Set of operations to run in Torch, regardless of converter coverage
         pass_through_build_failures (bool): Whether to fail on TRT engine build errors (True) or not (False)
         max_aux_streams (Optional[int]): Maximum number of allowed auxiliary TRT streams for each engine
         version_compatible (bool): Provide version forward-compatibility for engine plan files
@@ -517,8 +509,6 @@ def convert_module_to_trt_engine(
         dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
         dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
         dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
-        calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
-        allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
 
     Returns:
         bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
@@ -560,13 +550,25 @@ def convert_module_to_trt_engine(
         "dla_global_dram_size": dla_global_dram_size,
     }
 
+    # Decompose the exported program
+    exported_program = exported_program.run_decompositions(
+        get_decompositions(enable_experimental_decompositions)
+    )
+    gm = exported_program.module()
+    logger.debug("Input graph: " + str(gm.graph))
+
+    # Apply lowering on the graph module
+    torch_inputs = get_torch_inputs(input_list, device)
+    gm = apply_lowering_passes(gm, torch_inputs)
+    logger.debug("Lowered Input graph: " + str(gm.graph))
+
     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
     try:
-        interpreter_result = interpret_module_to_result(module, input_list, settings)
+        interpreter_result = interpret_module_to_result(gm, input_list, settings)
     except UnsupportedOperatorException:
         logger.error(
-            f"Conversion of module {module} not currently fully supported or convertible!",
+            f"Conversion of module {gm} not currently fully supported or convertible!",
             exc_info=True,
         )
     except Exception as e:
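
As the docstring edits above indicate, dynamo's convert_module_to_trt_engine now takes an ExportedProgram as its source and runs decomposition, lowering, and interpretation internally; method_name, input_signature, calibrator, and allow_shape_tensors are dropped. A minimal usage sketch mirroring the updated test (model and shapes are illustrative; a CUDA device and the engine filename are assumptions):

import torch
import torch_tensorrt

class Add(torch.nn.Module):
    def forward(self, a, b):
        return a + b

a, b = torch.randn(2, 3).cuda(), torch.randn(2, 3).cuda()

# torch.export now produces the ExportedProgram that the converter consumes.
exp_program = torch.export.export(Add().eval().cuda(), (a, b))

# Decomposition, lowering, and TRT interpretation happen inside this call;
# the result is a serialized engine that can be written straight to disk.
engine_bytes = torch_tensorrt.dynamo.convert_module_to_trt_engine(
    exp_program, inputs=[a, b]
)
with open("add.engine", "wb") as f:
    f.write(engine_bytes)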

Diff for: tests/py/dynamo/runtime/test_convert_method_to_trt_engine.py renamed to tests/py/dynamo/runtime/test_convert_module_to_trt_engine.py (+6, -4)

@@ -7,7 +7,7 @@
 from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
 
 
-class TestConvertMethodToTrtEngine(unittest.TestCase):
+class TestConvertModuleToTrtEngine(unittest.TestCase):
     def test_convert_module(self):
         class Test(torch.nn.Module):
             def forward(self, a, b):
@@ -18,19 +18,21 @@ def forward(self, a, b):
 
         # Create a model
         model = Test()
-        symbolic_traced_gm = torch.fx.symbolic_trace(model)
+        exp_program = torch.export.export(model, (input_data_0, input_data_1))
 
         # Convert to TensorRT engine
         trt_engine_str = torch_tensorrt.dynamo.convert_module_to_trt_engine(
-            symbolic_traced_gm, "forward", inputs=[input_data_0, input_data_1]
+            exp_program, inputs=(input_data_0, input_data_1)
         )
 
         # Deserialize the TensorRT engine
         with trt.Logger() as logger, trt.Runtime(logger) as runtime:
             engine = runtime.deserialize_cuda_engine(trt_engine_str)
 
         # Inference on TRT Engine
-        py_trt_module = PythonTorchTensorRTModule(engine, ["a", "b"], ["output0"])
+        py_trt_module = PythonTorchTensorRTModule(
+            engine, ["arg0_1", "arg1_1"], ["output0"]
+        )
         trt_output = py_trt_module(input_data_0, input_data_1).cpu()
 
         # Inference on PyTorch model
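
The engine binding names change from ["a", "b"] to ["arg0_1", "arg1_1"] because torch.export rewrites user arguments into graph placeholders instead of preserving parameter names. A quick sketch for inspecting the names an exported graph actually uses (the forward body here is illustrative, and the arg*_1 naming scheme is an internal detail of the torch.export version this test targets):

import torch

class Test(torch.nn.Module):
    def forward(self, a, b):
        return a + b  # illustrative body; the test's module body is elided above

ep = torch.export.export(Test(), (torch.randn(2, 3), torch.randn(2, 3)))

# List the placeholder (input) node names of the exported graph,
# e.g. ['arg0_1', 'arg1_1'], matching the engine bindings used in the test.
print([n.name for n in ep.graph.nodes if n.op == "placeholder"])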
