Skip to content

Commit

Permalink
preserve source positions for assertion rewriting (#12818)
Browse files Browse the repository at this point in the history
  • Loading branch information
15r10nk committed Sep 15, 2024
1 parent 9515dfa commit 5412956
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 11 deletions.
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ Feng Ma
Florian Bruhin
Florian Dahlitz
Floris Bruynooghe
Frank Hoffmann
Fraser Stark
Gabriel Landau
Gabriel Reis
Expand Down
1 change: 1 addition & 0 deletions changelog/12818.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
assertion rewriting preserves the source ranges of the original instructions.
19 changes: 12 additions & 7 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,7 @@ def assign(self, expr: ast.expr) -> ast.Name:
"""Give *expr* a name."""
name = self.variable()
self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
return ast.Name(name, ast.Load())
return ast.copy_location(ast.Name(name, ast.Load()), expr)

def display(self, expr: ast.expr) -> ast.expr:
"""Call saferepr on the expression."""
Expand Down Expand Up @@ -975,7 +975,8 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]:
# Fix locations (line numbers/column offsets).
for stmt in self.statements:
for node in traverse_node(stmt):
ast.copy_location(node, assert_)
if getattr(node, "lineno", None) is None:
ast.copy_location(node, assert_)
return self.statements

def visit_NamedExpr(self, name: ast.NamedExpr) -> tuple[ast.NamedExpr, str]:
Expand Down Expand Up @@ -1052,15 +1053,17 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]:
def visit_UnaryOp(self, unary: ast.UnaryOp) -> tuple[ast.Name, str]:
pattern = UNARY_MAP[unary.op.__class__]
operand_res, operand_expl = self.visit(unary.operand)
res = self.assign(ast.UnaryOp(unary.op, operand_res))
res = self.assign(ast.copy_location(ast.UnaryOp(unary.op, operand_res), unary))
return res, pattern % (operand_expl,)

def visit_BinOp(self, binop: ast.BinOp) -> tuple[ast.Name, str]:
symbol = BINOP_MAP[binop.op.__class__]
left_expr, left_expl = self.visit(binop.left)
right_expr, right_expl = self.visit(binop.right)
explanation = f"({left_expl} {symbol} {right_expl})"
res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
res = self.assign(
ast.copy_location(ast.BinOp(left_expr, binop.op, right_expr), binop)
)
return res, explanation

def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]:
Expand Down Expand Up @@ -1089,7 +1092,7 @@ def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]:
arg_expls.append("**" + expl)

expl = "{}({})".format(func_expl, ", ".join(arg_expls))
new_call = ast.Call(new_func, new_args, new_kwargs)
new_call = ast.copy_location(ast.Call(new_func, new_args, new_kwargs), call)
res = self.assign(new_call)
res_expl = self.explanation_param(self.display(res))
outer_expl = f"{res_expl}\n{{{res_expl} = {expl}\n}}"
Expand All @@ -1105,7 +1108,9 @@ def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]:
if not isinstance(attr.ctx, ast.Load):
return self.generic_visit(attr)
value, value_expl = self.visit(attr.value)
res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
res = self.assign(
ast.copy_location(ast.Attribute(value, attr.attr, ast.Load()), attr)
)
res_expl = self.explanation_param(self.display(res))
pat = "%s\n{%s = %s.%s\n}"
expl = pat % (res_expl, res_expl, value_expl, attr.attr)
Expand Down Expand Up @@ -1146,7 +1151,7 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
syms.append(ast.Constant(sym))
expl = f"{left_expl} {sym} {next_expl}"
expls.append(ast.Constant(expl))
res_expr = ast.Compare(left_res, [op], [next_res])
res_expr = ast.copy_location(ast.Compare(left_res, [op], [next_res]), comp)
self.statements.append(ast.Assign([store_names[i]], res_expr))
left_res, left_expl = next_res, next_expl
# Use pytest.assertion.util._reprcompare if that's available.
Expand Down
209 changes: 205 additions & 4 deletions testing/test_assertrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from __future__ import annotations

import ast
import dis
import errno
from functools import partial
import glob
import importlib
import inspect
import marshal
import os
from pathlib import Path
Expand Down Expand Up @@ -131,10 +133,209 @@ def test_location_is_set(self) -> None:
continue
for n in [node, *ast.iter_child_nodes(node)]:
assert isinstance(n, (ast.stmt, ast.expr))
assert n.lineno == 3
assert n.col_offset == 0
assert n.end_lineno == 6
assert n.end_col_offset == 3
for location in [
(n.lineno, n.col_offset),
(n.end_lineno, n.end_col_offset),
]:
assert (3, 0) <= location <= (6, 3)

def test_positions_are_preserved(self) -> None:
def preserved(code):
s = textwrap.dedent(code)
locations = []

def loc(msg=None):
frame = inspect.currentframe()
assert frame
frame = frame.f_back
assert frame
frame = frame.f_back
assert frame

offset = frame.f_lasti

instructions = {i.offset: i for i in dis.get_instructions(frame.f_code)}

# skip CACHE instructions
while offset not in instructions and offset >= 0:
offset -= 1

Check warning on line 161 in testing/test_assertrewrite.py

View check run for this annotation

Codecov / codecov/patch

testing/test_assertrewrite.py#L161

Added line #L161 was not covered by tests

instruction = instructions[offset]
if sys.version_info >= (3, 11):
position = instruction.positions

Check warning on line 165 in testing/test_assertrewrite.py

View check run for this annotation

Codecov / codecov/patch

testing/test_assertrewrite.py#L165

Added line #L165 was not covered by tests
else:
position = instruction.starts_line

locations.append((msg, instruction.opname, position))

globals = {"loc": loc}

m = rewrite(s)
mod = compile(m, "<string>", "exec")
exec(mod, globals, globals)
transformed_locations = locations
locations = []

mod = compile(s, "<string>", "exec")
exec(mod, globals, globals)
original_locations = locations

assert len(original_locations) > 0
assert original_locations == transformed_locations

preserved("""
def f():
loc()
return 8
assert f() in [8]
assert (f()
in
[8])
""")

preserved("""
class T:
def __init__(self):
loc("init")
def __getitem__(self,index):
loc("getitem")
return index
assert T()[5] == 5
assert (T
()
[5]
==
5)
""")

for name, op in [
("pos", "+"),
("neg", "-"),
("invert", "~"),
]:
preserved(f"""
class T:
def __{name}__(self):
loc("{name}")
return "{name}"
assert {op}T() == "{name}"
assert ({op}
T
()
==
"{name}")
""")

for name, op in [
("add", "+"),
("sub", "-"),
("mul", "*"),
("truediv", "/"),
("floordiv", "//"),
("mod", "%"),
("pow", "**"),
("lshift", "<<"),
("rshift", ">>"),
("or", "|"),
("xor", "^"),
("and", "&"),
("matmul", "@"),
]:
preserved(f"""
class T:
def __{name}__(self,other):
loc("{name}")
return other
def __r{name}__(self,other):
loc("r{name}")
return other
assert T() {op} 2 == 2
assert 2 {op} T() == 2
assert (T
()
{op}
2
==
2)
assert (2
{op}
T
()
==
2)
""")

for name, op in [
("eq", "=="),
("ne", "!="),
("lt", "<"),
("le", "<="),
("gt", ">"),
("ge", ">="),
]:
preserved(f"""
class T:
def __{name}__(self,other):
loc()
return True
assert T() {op} 5
assert (T
()
{op}
5)
""")

for name, op in [
("eq", "=="),
("ne", "!="),
("lt", ">"),
("le", ">="),
("gt", "<"),
("ge", "<="),
("contains", "in"),
]:
preserved(f"""
class T:
def __{name}__(self,other):
loc()
return True
assert 5 {op} T()
assert (5
{op}
T
())
""")

preserved("""
def func(value):
loc("func")
return value
class T:
def __iter__(self):
loc("iter")
return iter([5])
assert func(*T()) == 5
""")

preserved("""
class T:
def __getattr__(self,name):
loc()
return name
assert T().attr == "attr"
""")

def test_dont_rewrite(self) -> None:
s = """'PYTEST_DONT_REWRITE'\nassert 14"""
Expand Down

0 comments on commit 5412956

Please sign in to comment.