Skip to content

Commit 533798e

Browse files
StrongerXipytorchmergebot
authored andcommitted
[dynamo] Enforce some invariants on ConstantVariable.create (pytorch#140984)
This addresses pytorch#140745 (comment). Pull Request resolved: pytorch#140984 Approved by: https://github.com/jansel ghstack dependencies: pytorch#141504
1 parent 3141e03 commit 533798e

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

torch/_dynamo/variables/base.py

+3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .. import variables
88
from ..current_scope_id import current_scope_id
99
from ..exc import unimplemented
10+
from ..guards import GuardBuilder, install_guard
1011
from ..source import AttrSource, Source
1112
from ..utils import istype
1213

@@ -328,6 +329,8 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke
328329
if not variables.ConstantVariable.is_literal(value):
329330
raise NotImplementedError
330331
source = self.source and AttrSource(self.source, name)
332+
if source:
333+
install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH))
331334
return variables.ConstantVariable.create(value, source=source)
332335

333336
def is_proxy(self):

torch/_dynamo/variables/constant.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from .. import variables
1010
from ..exc import unimplemented, UserError, UserErrorType
11-
from ..guards import GuardBuilder, install_guard
1211
from ..utils import common_constant_types, istype, np
1312
from .base import typestr, VariableTracker
1413

@@ -24,6 +23,9 @@ def create(value, **kwargs) -> VariableTracker:
2423
Create a `ConstantVariable` based on the given value, and supports
2524
automatic routing for collection types like `tuple` (in which case we'd
2625
create `ConstantVariable` for the leaf items).
26+
27+
NOTE: the caller must install the proper guards if needed; most often
28+
the guard will be `CONSTANT_MATCH`.
2729
"""
2830
source = kwargs.get("source", None)
2931

@@ -38,8 +40,6 @@ def create(value, **kwargs) -> VariableTracker:
3840
items = []
3941
for i, x in enumerate(value):
4042
item_source = GetItemSource(source, i) if source else None
41-
if item_source:
42-
install_guard(item_source.make_guard(GuardBuilder.CONSTANT_MATCH))
4343
items.append(
4444
ConstantVariable.create(
4545
x,

torch/_dynamo/variables/misc.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke
145145
return GetAttrVariable(self, name)
146146
if source:
147147
install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH))
148-
return variables.ConstantVariable.create(value, source=source)
149-
return variables.ConstantVariable.create(value)
148+
return variables.ConstantVariable.create(value, source=source)
150149

151150
def call_method(
152151
self,

0 commit comments

Comments
 (0)