Skip to content

Commit dabc04f

Browse files
committed
fix: FakeTensors appearing in get_attr calls
- Register all constants as model parameters, which do not get fake-ified by the active FakeTensor context
- Buffers and other constant registrations can be fake-ified, which is problematic for TRT tracing
1 parent fb07513 commit dabc04f

File tree

2 files changed

+33
-8
lines changed

2 files changed

+33
-8
lines changed

py/torch_tensorrt/dynamo/conversion/_conversion.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
TRTInterpreterResult,
1313
)
1414
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
15-
from torch_tensorrt.dynamo.utils import get_torch_inputs, to_torch_device
15+
from torch_tensorrt.dynamo.utils import get_torch_inputs
1616

1717

1818
def interpret_module_to_result(
@@ -29,7 +29,6 @@ def interpret_module_to_result(
2929
TRTInterpreterResult
3030
"""
3131
torch_inputs = get_torch_inputs(inputs, settings.device)
32-
module.to(to_torch_device(settings.device))
3332
module_outputs = module(*torch_inputs)
3433

3534
if not isinstance(module_outputs, (list, tuple)):

py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

+32-6
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,9 @@
1111

1212
# Modify import location of utilities based on Torch version
1313
if version.parse(sanitized_torch_version()) < version.parse("2.1.1"):
14-
from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
14+
from torch._inductor.freezing import ConstantFolder
1515
else:
16-
from torch._inductor.constant_folding import (
17-
ConstantFolder,
18-
replace_node_with_constant,
19-
)
16+
from torch._inductor.constant_folding import ConstantFolder
2017

2118
logger = logging.getLogger(__name__)
2219

@@ -36,7 +33,7 @@ def constant_fold(
3633
cf.run()
3734

3835
for node, constant in cf.node_replacements.items():
39-
replace_node_with_constant(gm, node, constant)
36+
replace_node_with_constant(gm, node, torch.nn.Parameter(constant.cuda()))
4037

4138
erased_params = []
4239
for node in gm.graph.nodes:
@@ -60,6 +57,35 @@ def constant_fold(
6057
return gm
6158

6259

60+
def replace_node_with_constant(
61+
gm: torch.fx.GraphModule, node: torch.fx.Node, constant: torch.Tensor
62+
) -> None:
63+
g = gm.graph
64+
65+
if not hasattr(gm, "_frozen_param_count"):
66+
gm._frozen_param_count = 0
67+
68+
i = gm._frozen_param_count
69+
70+
while True:
71+
qualname = f"_frozen_param{i}"
72+
if not hasattr(gm, qualname):
73+
break
74+
i += 1
75+
76+
gm._frozen_param_count = i + 1
77+
78+
with g.inserting_before(node):
79+
new_input_node = g.create_node("get_attr", qualname, (), {})
80+
node.replace_all_uses_with(new_input_node)
81+
new_input_node.meta.update(node.meta)
82+
g.erase_node(node)
83+
84+
# Needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
85+
gm.register_parameter(qualname, constant)
86+
setattr(gm, qualname, constant)
87+
88+
6389
# TODO: Delete this class when the following code is fixed in nightly:
6490
# https://github.com/pytorch/pytorch/blob/4b881b0da390c1290bb12850ef9daad6f6eb2cb6/torch/_inductor/constant_folding.py#L53-L63
6591
class _TorchTensorRTConstantFolder(ConstantFolder): # type: ignore[misc]

0 commit comments

Comments
 (0)