Skip to content

Commit f9cce85

Browse files
committed
preserve source positions for assertion rewriting
1 parent 9515dfa commit f9cce85

File tree

2 files changed

+195
-10
lines changed

2 files changed

+195
-10
lines changed

src/_pytest/assertion/rewrite.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -975,7 +975,8 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]:
975975
# Fix locations (line numbers/column offsets).
976976
for stmt in self.statements:
977977
for node in traverse_node(stmt):
978-
ast.copy_location(node, assert_)
978+
if getattr(node, "lineno", None) is None:
979+
ast.copy_location(node, assert_)
979980
return self.statements
980981

981982
def visit_NamedExpr(self, name: ast.NamedExpr) -> tuple[ast.NamedExpr, str]:
@@ -1052,15 +1053,17 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]:
10521053
def visit_UnaryOp(self, unary: ast.UnaryOp) -> tuple[ast.Name, str]:
10531054
pattern = UNARY_MAP[unary.op.__class__]
10541055
operand_res, operand_expl = self.visit(unary.operand)
1055-
res = self.assign(ast.UnaryOp(unary.op, operand_res))
1056+
res = self.assign(ast.copy_location(ast.UnaryOp(unary.op, operand_res), unary))
10561057
return res, pattern % (operand_expl,)
10571058

10581059
def visit_BinOp(self, binop: ast.BinOp) -> tuple[ast.Name, str]:
10591060
symbol = BINOP_MAP[binop.op.__class__]
10601061
left_expr, left_expl = self.visit(binop.left)
10611062
right_expr, right_expl = self.visit(binop.right)
10621063
explanation = f"({left_expl} {symbol} {right_expl})"
1063-
res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
1064+
res = self.assign(
1065+
ast.copy_location(ast.BinOp(left_expr, binop.op, right_expr), binop)
1066+
)
10641067
return res, explanation
10651068

10661069
def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]:
@@ -1089,7 +1092,7 @@ def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]:
10891092
arg_expls.append("**" + expl)
10901093

10911094
expl = "{}({})".format(func_expl, ", ".join(arg_expls))
1092-
new_call = ast.Call(new_func, new_args, new_kwargs)
1095+
new_call = ast.copy_location(ast.Call(new_func, new_args, new_kwargs), call)
10931096
res = self.assign(new_call)
10941097
res_expl = self.explanation_param(self.display(res))
10951098
outer_expl = f"{res_expl}\n{{{res_expl} = {expl}\n}}"
@@ -1105,7 +1108,9 @@ def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]:
11051108
if not isinstance(attr.ctx, ast.Load):
11061109
return self.generic_visit(attr)
11071110
value, value_expl = self.visit(attr.value)
1108-
res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
1111+
res = self.assign(
1112+
ast.copy_location(ast.Attribute(value, attr.attr, ast.Load()), attr)
1113+
)
11091114
res_expl = self.explanation_param(self.display(res))
11101115
pat = "%s\n{%s = %s.%s\n}"
11111116
expl = pat % (res_expl, res_expl, value_expl, attr.attr)
@@ -1146,7 +1151,7 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
11461151
syms.append(ast.Constant(sym))
11471152
expl = f"{left_expl} {sym} {next_expl}"
11481153
expls.append(ast.Constant(expl))
1149-
res_expr = ast.Compare(left_res, [op], [next_res])
1154+
res_expr = ast.copy_location(ast.Compare(left_res, [op], [next_res]), comp)
11501155
self.statements.append(ast.Assign([store_names[i]], res_expr))
11511156
left_res, left_expl = next_res, next_expl
11521157
# Use pytest.assertion.util._reprcompare if that's available.

testing/test_assertrewrite.py

Lines changed: 184 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
from __future__ import annotations
33

44
import ast
5+
import dis
56
import errno
67
from functools import partial
78
import glob
89
import importlib
10+
import inspect
911
import marshal
1012
import os
1113
from pathlib import Path
@@ -131,10 +133,188 @@ def test_location_is_set(self) -> None:
131133
continue
132134
for n in [node, *ast.iter_child_nodes(node)]:
133135
assert isinstance(n, (ast.stmt, ast.expr))
134-
assert n.lineno == 3
135-
assert n.col_offset == 0
136-
assert n.end_lineno == 6
137-
assert n.end_col_offset == 3
136+
for location in [
137+
(n.lineno, n.col_offset),
138+
(n.end_lineno, n.end_col_offset),
139+
]:
140+
assert (3, 0) <= location <= (6, 3)
141+
142+
def test_positions_are_preserved(self) -> None:
143+
def preserved(code):
144+
s = textwrap.dedent(code)
145+
locations = []
146+
147+
def loc(msg=None):
148+
frame = inspect.currentframe()
149+
assert frame
150+
frame = frame.f_back
151+
assert frame
152+
frame = frame.f_back
153+
assert frame
154+
155+
offset = frame.f_lasti
156+
157+
instructions = {i.offset: i for i in dis.get_instructions(frame.f_code)}
158+
159+
# skip CACHE instructions
160+
while offset not in instructions and offset >= 0:
161+
offset -= 1
162+
163+
instruction = instructions[offset]
164+
if sys.version_info >= (3, 11):
165+
position = instruction.positions
166+
else:
167+
position = instruction.starts_line
168+
169+
locations.append((msg, instruction.opname, position))
170+
171+
globals = {"loc": loc}
172+
173+
m = rewrite(s)
174+
mod = compile(m, "<string>", "exec")
175+
exec(mod, globals, globals)
176+
transformed_locations = locations
177+
locations = []
178+
179+
mod = compile(s, "<string>", "exec")
180+
exec(mod, globals, globals)
181+
original_locations = locations
182+
183+
assert len(original_locations) > 0
184+
assert original_locations == transformed_locations
185+
186+
preserved("""
187+
def f():
188+
loc()
189+
return 8
190+
191+
assert f() in [8]
192+
assert (f()
193+
in
194+
[8])
195+
""")
196+
197+
preserved("""
198+
class T:
199+
def __init__(self):
200+
loc("init")
201+
def __getitem__(self,index):
202+
loc("getitem")
203+
return index
204+
205+
assert T()[5] == 5
206+
assert (T
207+
()
208+
[5]
209+
==
210+
5)
211+
""")
212+
213+
for name, op in [
214+
("pos", "+"),
215+
("neg", "-"),
216+
("invert", "~"),
217+
]:
218+
preserved(f"""
219+
class T:
220+
def __{name}__(self):
221+
loc("{name}")
222+
return "{name}"
223+
224+
assert {op}T() == "{name}"
225+
assert ({op}
226+
T
227+
()
228+
==
229+
"{name}")
230+
""")
231+
232+
for name, op in [
233+
("add", "+"),
234+
("sub", "-"),
235+
("mul", "*"),
236+
("truediv", "/"),
237+
("floordiv", "//"),
238+
("mod", "%"),
239+
("pow", "**"),
240+
("lshift", "<<"),
241+
("rshift", ">>"),
242+
("or", "|"),
243+
("xor", "^"),
244+
("and", "&"),
245+
("matmul", "@"),
246+
]:
247+
preserved(f"""
248+
class T:
249+
def __{name}__(self,other):
250+
loc("{name}")
251+
return other
252+
253+
def __r{name}__(self,other):
254+
loc("r{name}")
255+
return other
256+
257+
assert T() {op} 2 == 2
258+
assert 2 {op} T() == 2
259+
260+
assert (T
261+
()
262+
{op}
263+
2
264+
==
265+
2
266+
)
267+
268+
assert (2
269+
{op}
270+
T
271+
()
272+
==
273+
2)
274+
""")
275+
276+
for name, op in [
277+
("eq", "=="),
278+
("ne", "!="),
279+
("lt", "<"),
280+
("le", "<="),
281+
("gt", ">"),
282+
("ge", ">="),
283+
]:
284+
preserved(f"""
285+
class T:
286+
def __{name}__(self,other):
287+
loc()
288+
return True
289+
290+
assert T() {op} 5
291+
assert (T
292+
()
293+
{op}
294+
5)
295+
""")
296+
297+
for name, op in [
298+
("eq", "=="),
299+
("ne", "!="),
300+
("lt", ">"),
301+
("le", ">="),
302+
("gt", "<"),
303+
("ge", "<="),
304+
("contains", "in"),
305+
]:
306+
preserved(f"""
307+
class T:
308+
def __{name}__(self,other):
309+
loc()
310+
return True
311+
312+
assert 5 {op} T()
313+
assert (5
314+
{op}
315+
T
316+
())
317+
""")
138318

139319
def test_dont_rewrite(self) -> None:
140320
s = """'PYTEST_DONT_REWRITE'\nassert 14"""

0 commit comments

Comments
 (0)