Skip to content

Commit

Permalink
preserve source positions for assertion rewriting (#12818)
Browse files Browse the repository at this point in the history
  • Loading branch information
15r10nk committed Sep 15, 2024
1 parent 9515dfa commit e371879
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 11 deletions.
1 change: 1 addition & 0 deletions changelog/12818.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
assertion rewriting preserves the source ranges of the original instructions.
19 changes: 12 additions & 7 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,7 @@ def assign(self, expr: ast.expr) -> ast.Name:
"""Give *expr* a name."""
name = self.variable()
self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
return ast.Name(name, ast.Load())
return ast.copy_location(ast.Name(name, ast.Load()), expr)

def display(self, expr: ast.expr) -> ast.expr:
"""Call saferepr on the expression."""
Expand Down Expand Up @@ -975,7 +975,8 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]:
# Fix locations (line numbers/column offsets).
for stmt in self.statements:
for node in traverse_node(stmt):
ast.copy_location(node, assert_)
if getattr(node, "lineno", None) is None:
ast.copy_location(node, assert_)
return self.statements

def visit_NamedExpr(self, name: ast.NamedExpr) -> tuple[ast.NamedExpr, str]:
Expand Down Expand Up @@ -1052,15 +1053,17 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]:
def visit_UnaryOp(self, unary: ast.UnaryOp) -> tuple[ast.Name, str]:
pattern = UNARY_MAP[unary.op.__class__]
operand_res, operand_expl = self.visit(unary.operand)
res = self.assign(ast.UnaryOp(unary.op, operand_res))
res = self.assign(ast.copy_location(ast.UnaryOp(unary.op, operand_res), unary))
return res, pattern % (operand_expl,)

def visit_BinOp(self, binop: ast.BinOp) -> tuple[ast.Name, str]:
symbol = BINOP_MAP[binop.op.__class__]
left_expr, left_expl = self.visit(binop.left)
right_expr, right_expl = self.visit(binop.right)
explanation = f"({left_expl} {symbol} {right_expl})"
res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
res = self.assign(
ast.copy_location(ast.BinOp(left_expr, binop.op, right_expr), binop)
)
return res, explanation

def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]:
Expand Down Expand Up @@ -1089,7 +1092,7 @@ def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]:
arg_expls.append("**" + expl)

expl = "{}({})".format(func_expl, ", ".join(arg_expls))
new_call = ast.Call(new_func, new_args, new_kwargs)
new_call = ast.copy_location(ast.Call(new_func, new_args, new_kwargs), call)
res = self.assign(new_call)
res_expl = self.explanation_param(self.display(res))
outer_expl = f"{res_expl}\n{{{res_expl} = {expl}\n}}"
Expand All @@ -1105,7 +1108,9 @@ def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]:
if not isinstance(attr.ctx, ast.Load):
return self.generic_visit(attr)
value, value_expl = self.visit(attr.value)
res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
res = self.assign(
ast.copy_location(ast.Attribute(value, attr.attr, ast.Load()), attr)
)
res_expl = self.explanation_param(self.display(res))
pat = "%s\n{%s = %s.%s\n}"
expl = pat % (res_expl, res_expl, value_expl, attr.attr)
Expand Down Expand Up @@ -1146,7 +1151,7 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
syms.append(ast.Constant(sym))
expl = f"{left_expl} {sym} {next_expl}"
expls.append(ast.Constant(expl))
res_expr = ast.Compare(left_res, [op], [next_res])
res_expr = ast.copy_location(ast.Compare(left_res, [op], [next_res]), comp)
self.statements.append(ast.Assign([store_names[i]], res_expr))
left_res, left_expl = next_res, next_expl
# Use pytest.assertion.util._reprcompare if that's available.
Expand Down
212 changes: 208 additions & 4 deletions testing/test_assertrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from __future__ import annotations

import ast
import dis
import errno
from functools import partial
import glob
import importlib
import inspect
import marshal
import os
from pathlib import Path
Expand Down Expand Up @@ -131,10 +133,212 @@ def test_location_is_set(self) -> None:
continue
for n in [node, *ast.iter_child_nodes(node)]:
assert isinstance(n, (ast.stmt, ast.expr))
assert n.lineno == 3
assert n.col_offset == 0
assert n.end_lineno == 6
assert n.end_col_offset == 3
for location in [
(n.lineno, n.col_offset),
(n.end_lineno, n.end_col_offset),
]:
assert (3, 0) <= location <= (6, 3)

def test_positions_are_preserved(self) -> None:
def preserved(code):
s = textwrap.dedent(code)
locations = []

def loc(msg=None):
frame = inspect.currentframe()
assert frame
frame = frame.f_back
assert frame
frame = frame.f_back
assert frame

offset = frame.f_lasti

instructions = {i.offset: i for i in dis.get_instructions(frame.f_code)}

# skip CACHE instructions
while offset not in instructions and offset >= 0:
offset -= 1

instruction = instructions[offset]
if sys.version_info >= (3, 11):
position = instruction.positions
else:
position = instruction.starts_line

locations.append((msg, instruction.opname, position))

globals = {"loc": loc}

m = rewrite(s)
mod = compile(m, "<string>", "exec")
exec(mod, globals, globals)
transformed_locations = locations
locations = []

mod = compile(s, "<string>", "exec")
exec(mod, globals, globals)
original_locations = locations

assert len(original_locations) > 0
assert original_locations == transformed_locations

preserved("""
def f():
loc()
return 8
assert f() in [8]
assert (f()
in
[8])
""")

preserved("""
class T:
def __init__(self):
loc("init")
def __getitem__(self,index):
loc("getitem")
return index
assert T()[5] == 5
assert (T
()
[5]
==
5)
""")

for name, op in [
("pos", "+"),
("neg", "-"),
("invert", "~"),
]:
preserved(f"""
class T:
def __{name}__(self):
loc("{name}")
return "{name}"
assert {op}T() == "{name}"
assert ({op}
T
()
==
"{name}")
""")

for name, op in [
("add", "+"),
("sub", "-"),
("mul", "*"),
("truediv", "/"),
("floordiv", "//"),
("mod", "%"),
("pow", "**"),
("lshift", "<<"),
("rshift", ">>"),
("or", "|"),
("xor", "^"),
("and", "&"),
("matmul", "@"),
]:
preserved(f"""
class T:
def __{name}__(self,other):
loc("{name}")
return other
def __r{name}__(self,other):
loc("r{name}")
return other
assert T() {op} 2 == 2
assert 2 {op} T() == 2
assert (T
()
{op}
2
==
2)
assert (2
{op}
T
()
==
2)
""")

for name, op in [
("eq", "=="),
("ne", "!="),
("lt", "<"),
("le", "<="),
("gt", ">"),
("ge", ">="),
]:
preserved(f"""
class T:
def __{name}__(self,other):
loc()
return True
assert T() {op} 5
assert (T
()
{op}
5)
""")

for name, op in [
("eq", "=="),
("ne", "!="),
("lt", ">"),
("le", ">="),
("gt", "<"),
("ge", "<="),
("contains", "in"),
]:
preserved(f"""
class T:
def __{name}__(self,other):
loc()
return True
assert 5 {op} T()
assert (5
{op}
T
())
""")



preserved(f"""
def func(value):
loc("func")
return value
class T:
def __iter__(self):
loc("iter")
return iter([5])
assert func(*T()) == 5
""")

preserved(f"""
class T:
def __getattr__(self,name):
loc()
return name
assert T().attr == "attr"
""")


def test_dont_rewrite(self) -> None:
s = """'PYTEST_DONT_REWRITE'\nassert 14"""
Expand Down

0 comments on commit e371879

Please sign in to comment.