
❓ [Question] dynamo conversion failing w/ TRTInterpreter #3124

Open
@patrick-botco

Description


❓ Question

I'm able to torch.export and generate an ExportedProgram for my model with no issues. Upon compiling with torch_tensorrt...

ep = torch.export.load("...")
example_inputs = ep.example_inputs[0]
model = ep.module().to("cuda")

compile_spec = {
    "ir": "torch_compile",
    "inputs": example_inputs,
    "enabled_precisions": enabled_precisions,
    "workspace_size": workspace_size,
    "min_block_size": min_block_size,
    "torch_executed_ops": {},
    "sparse_weights": True,
}

optimized_model = torch_tensorrt.compile(model, **compile_spec)

... I run into this error:

ERROR:torch_tensorrt [TensorRT Conversion Context]:INetworkDefinition::addConstant: Error Code 3: API Usage Error (Parameter check failed, condition: !weights.values == !weights.count. )
Traceback (most recent call last):
...
  File ".../lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 479, in run
    self._construct_trt_network_def()
  File ".../lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 325, in _construct_trt_network_def
    super().run()
  File ".../lib/python3.10/site-packages/torch/fx/interpreter.py", line 145, in run
    self.env[node] = self.run_node(node)
  File ".../lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 529, in run_node
    trt_node: torch.fx.Node = super().run_node(n)
  File ".../lib/python3.10/site-packages/torch/fx/interpreter.py", line 202, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File ".../lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 638, in call_function
    return converter(self.ctx, target, args, kwargs, self._cur_node_name)
  File ".../lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 242, in aten_ops_cat
    return impl.cat.cat(
  File ".../lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/impl/cat.py", line 31, in cat
    each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{i}")
  File ".../lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/converter_utils.py", line 384, in get_trt_tensor
    return create_constant(ctx, input_val, name, dtype, min_rank)
  File ".../lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/converter_utils.py", line 349, in create_constant
    constant.name = name
torch._dynamo.exc.BackendCompilerFailed: backend='torch_tensorrt_backend' raised:
AttributeError: 'NoneType' object has no attribute 'name'

I'm currently able to cleanly generate an ExportedProgram via torch.export, and outputs from the trace match the original PyTorch model. In particular, it's unclear to me why `!weights.values == !weights.count` would be an API Usage Error, or why there is a discrepancy between torch.compile and how torch_tensorrt interprets / performs the op conversion (torch.compile on the ExportedProgram module works fine).

What you have already tried

I've narrowed the issue down to a single module that does positional encoding. The output of this module is then concatenated with another tensor, which is where the error above occurs. Without this module, everything works as expected, and I see about a 5x speedup.

The only unique thing about this module is that it has a buffer and some in-place operations; however, I've dumped and manually inspected the fx Graph, and the trace looks correct (the buffer is lifted as a constant input). I've also tried rewriting the forward so that there are no in-place operations, to make graph capture easier.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • PyTorch Version (e.g., 1.0): 2.4
  • CPU Architecture: aarch64
  • OS (e.g., Linux): Ubuntu
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source): modified bazel build rules + install
  • Are you using local sources or building from archives: local build from source
  • Python version: 3.10
  • CUDA version: 12.4
  • GPU models and configuration: Ampere (Jetson Nano, JetPack 6.0)
  • Any other relevant information: I compiled torch_tensorrt on HEAD of main as of last Friday (8/23)

Additional context

cc @narendasan, not sure if you have any insight here. Thanks!

Labels

question: Further information is requested