11
11
12
12
# Modify import location of utilities based on Torch version
13
13
if version .parse (sanitized_torch_version ()) < version .parse ("2.1.1" ):
14
- from torch ._inductor .freezing import ConstantFolder , replace_node_with_constant
14
+ from torch ._inductor .freezing import ConstantFolder
15
15
else :
16
- from torch ._inductor .constant_folding import (
17
- ConstantFolder ,
18
- replace_node_with_constant ,
19
- )
16
+ from torch ._inductor .constant_folding import ConstantFolder
20
17
21
18
logger = logging .getLogger (__name__ )
22
19
@@ -36,7 +33,7 @@ def constant_fold(
36
33
cf .run ()
37
34
38
35
for node , constant in cf .node_replacements .items ():
39
- replace_node_with_constant (gm , node , constant )
36
+ replace_node_with_constant (gm , node , torch . nn . Parameter ( constant . cuda ()) )
40
37
41
38
erased_params = []
42
39
for node in gm .graph .nodes :
@@ -60,6 +57,35 @@ def constant_fold(
60
57
return gm
61
58
62
59
60
+ def replace_node_with_constant (
61
+ gm : torch .fx .GraphModule , node : torch .fx .Node , constant : torch .Tensor
62
+ ) -> None :
63
+ g = gm .graph
64
+
65
+ if not hasattr (gm , "_frozen_param_count" ):
66
+ gm ._frozen_param_count = 0
67
+
68
+ i = gm ._frozen_param_count
69
+
70
+ while True :
71
+ qualname = f"_frozen_param{ i } "
72
+ if not hasattr (gm , qualname ):
73
+ break
74
+ i += 1
75
+
76
+ gm ._frozen_param_count = i + 1
77
+
78
+ with g .inserting_before (node ):
79
+ new_input_node = g .create_node ("get_attr" , qualname , (), {})
80
+ node .replace_all_uses_with (new_input_node )
81
+ new_input_node .meta .update (node .meta )
82
+ g .erase_node (node )
83
+
84
+ # Needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
85
+ gm .register_parameter (qualname , constant )
86
+ setattr (gm , qualname , constant )
87
+
88
+
63
89
# TODO: Delete this class when the following code is fixed in nightly:
64
90
# https://github.com/pytorch/pytorch/blob/4b881b0da390c1290bb12850ef9daad6f6eb2cb6/torch/_inductor/constant_folding.py#L53-L63
65
91
class _TorchTensorRTConstantFolder (ConstantFolder ): # type: ignore[misc]
0 commit comments