Skip to content

Commit 243222e

Browse files
authored
Fix slicing list returning wrong result (#222)
* Correct ListVariable source * Fix lint
1 parent 69703c2 commit 243222e

File tree

3 files changed

+19
-4
lines changed

3 files changed

+19
-4
lines changed

tests/test_misc.py

+12
Original file line numberDiff line numberDiff line change
@@ -1157,3 +1157,15 @@ def fn():
11571157
inst = dis.get_instructions(fn)
11581158
result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno)
11591159
self.assertTrue(result[1] == fn.__code__.co_lnotab)
1160+
1161+
def test_python_slice(self):
1162+
def fn(input):
1163+
y = 0
1164+
for i, x in enumerate(input[2:], 1):
1165+
y = y + x
1166+
return y
1167+
1168+
cnts = torchdynamo.testing.CompileCounter()
1169+
with torchdynamo.optimize(cnts):
1170+
z = fn([1, 2, 3, 5])
1171+
self.assertEqual(z, 8)

torchdynamo/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def rot_n_helper(n):
319319
def is_safe_constant(v):
320320
if istype(v, (tuple, frozenset)):
321321
return all(map(is_safe_constant, v))
322-
return istype(v, (types.CodeType, int, float, bool, str, bytes, type(None)))
322+
return istype(v, (types.CodeType, int, float, bool, str, bytes, type(None), slice))
323323

324324

325325
def check_constant_args(args, kwargs):

torchdynamo/variables/lists.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .. import variables
77
from ..bytecode_transformation import create_instruction
88
from ..exc import unimplemented
9+
from ..source import GetItemSource
910
from ..utils import namedtuple_fields
1011
from .base import MutableLocal
1112
from .base import VariableTracker
@@ -40,9 +41,11 @@ def as_proxy(self):
4041
def getitem_const(self, arg: VariableTracker):
4142
index = arg.as_python_constant()
4243
if isinstance(index, slice):
43-
return self.clone(items=self.items[index], mutable_local=None).add_options(
44-
arg, self
45-
)
44+
return self.clone(
45+
items=self.items[index],
46+
source=GetItemSource(self.source, index),
47+
mutable_local=None,
48+
).add_options(arg, self)
4649
else:
4750
assert isinstance(index, int)
4851
return self.items[index].add_options(arg, self)

0 commit comments

Comments
 (0)