diff --git a/AUTHORS b/AUTHORS index 374e6ad9bcc..01cf0047753 100644 --- a/AUTHORS +++ b/AUTHORS @@ -160,6 +160,7 @@ Feng Ma Florian Bruhin Florian Dahlitz Floris Bruynooghe +Frank Hoffmann Fraser Stark Gabriel Landau Gabriel Reis diff --git a/changelog/12818.bugfix.rst b/changelog/12818.bugfix.rst new file mode 100644 index 00000000000..94658a278f5 --- /dev/null +++ b/changelog/12818.bugfix.rst @@ -0,0 +1 @@ +assertion rewriting preserves the source ranges of the original instructions. diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index a7a92c0f1fe..c1fc2bf6dae 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -792,7 +792,7 @@ def assign(self, expr: ast.expr) -> ast.Name: """Give *expr* a name.""" name = self.variable() self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr)) - return ast.Name(name, ast.Load()) + return ast.copy_location(ast.Name(name, ast.Load()), expr) def display(self, expr: ast.expr) -> ast.expr: """Call saferepr on the expression.""" @@ -975,7 +975,8 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]: # Fix locations (line numbers/column offsets). for stmt in self.statements: for node in traverse_node(stmt): - ast.copy_location(node, assert_) + if getattr(node, "lineno", None) is None: + ast.copy_location(node, assert_) return self.statements def visit_NamedExpr(self, name: ast.NamedExpr) -> tuple[ast.NamedExpr, str]: @@ -1052,7 +1053,7 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]: def visit_UnaryOp(self, unary: ast.UnaryOp) -> tuple[ast.Name, str]: pattern = UNARY_MAP[unary.op.__class__] operand_res, operand_expl = self.visit(unary.operand) - res = self.assign(ast.UnaryOp(unary.op, operand_res)) + res = self.assign(ast.copy_location(ast.UnaryOp(unary.op, operand_res), unary)) return res, pattern % (operand_expl,) def visit_BinOp(self, binop: ast.BinOp) -> tuple[ast.Name, str]: @@ -1060,7 +1061,9 @@ def visit_BinOp(self, binop: ast.BinOp) -> tuple[ast.Name, str]: left_expr, left_expl = self.visit(binop.left) right_expr, right_expl = self.visit(binop.right) explanation = f"({left_expl} {symbol} {right_expl})" - res = self.assign(ast.BinOp(left_expr, binop.op, right_expr)) + res = self.assign( + ast.copy_location(ast.BinOp(left_expr, binop.op, right_expr), binop) + ) return res, explanation def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]: @@ -1089,7 +1092,7 @@ def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]: arg_expls.append("**" + expl) expl = "{}({})".format(func_expl, ", ".join(arg_expls)) - new_call = ast.Call(new_func, new_args, new_kwargs) + new_call = ast.copy_location(ast.Call(new_func, new_args, new_kwargs), call) res = self.assign(new_call) res_expl = self.explanation_param(self.display(res)) outer_expl = f"{res_expl}\n{{{res_expl} = {expl}\n}}" @@ -1105,7 +1108,9 @@ def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]: if not isinstance(attr.ctx, ast.Load): return self.generic_visit(attr) value, value_expl = self.visit(attr.value) - res = self.assign(ast.Attribute(value, attr.attr, ast.Load())) + res = self.assign( + ast.copy_location(ast.Attribute(value, attr.attr, ast.Load()), attr) + ) res_expl = self.explanation_param(self.display(res)) pat = "%s\n{%s = %s.%s\n}" expl = pat % (res_expl, res_expl, value_expl, attr.attr) @@ -1146,7 +1151,7 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]: syms.append(ast.Constant(sym)) expl = f"{left_expl} {sym} {next_expl}" expls.append(ast.Constant(expl)) - res_expr = ast.Compare(left_res, [op], [next_res]) + res_expr = ast.copy_location(ast.Compare(left_res, [op], [next_res]), comp) self.statements.append(ast.Assign([store_names[i]], res_expr)) left_res, left_expl = next_res, next_expl # Use pytest.assertion.util._reprcompare if that's available. diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 73c11a1a9d8..21786888604 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -2,10 +2,12 @@ from __future__ import annotations import ast +import dis import errno from functools import partial import glob import importlib +import inspect import marshal import os from pathlib import Path @@ -131,10 +133,209 @@ def test_location_is_set(self) -> None: continue for n in [node, *ast.iter_child_nodes(node)]: assert isinstance(n, (ast.stmt, ast.expr)) - assert n.lineno == 3 - assert n.col_offset == 0 - assert n.end_lineno == 6 - assert n.end_col_offset == 3 + for location in [ + (n.lineno, n.col_offset), + (n.end_lineno, n.end_col_offset), + ]: + assert (3, 0) <= location <= (6, 3) + + def test_positions_are_preserved(self) -> None: + def preserved(code): + s = textwrap.dedent(code) + locations = [] + + def loc(msg=None): + frame = inspect.currentframe() + assert frame + frame = frame.f_back + assert frame + frame = frame.f_back + assert frame + + offset = frame.f_lasti + + instructions = {i.offset: i for i in dis.get_instructions(frame.f_code)} + + # skip CACHE instructions + while offset not in instructions and offset >= 0: + offset -= 1 + + instruction = instructions[offset] + if sys.version_info >= (3, 11): + position = instruction.positions + else: + position = instruction.starts_line + + locations.append((msg, instruction.opname, position)) + + globals = {"loc": loc} + + m = rewrite(s) + mod = compile(m, "", "exec") + exec(mod, globals, globals) + transformed_locations = locations + locations = [] + + mod = compile(s, "", "exec") + exec(mod, globals, globals) + original_locations = locations + + assert len(original_locations) > 0 + assert original_locations == transformed_locations + + preserved(""" + def f(): + loc() + return 8 + + assert f() in [8] + assert (f() + in + [8]) + """) + + preserved(""" + class T: + def __init__(self): + loc("init") + def __getitem__(self,index): + loc("getitem") + return index + + assert T()[5] == 5 + assert (T + () + [5] + == + 5) + """) + + for name, op in [ + ("pos", "+"), + ("neg", "-"), + ("invert", "~"), + ]: + preserved(f""" + class T: + def __{name}__(self): + loc("{name}") + return "{name}" + + assert {op}T() == "{name}" + assert ({op} + T + () + == + "{name}") + """) + + for name, op in [ + ("add", "+"), + ("sub", "-"), + ("mul", "*"), + ("truediv", "/"), + ("floordiv", "//"), + ("mod", "%"), + ("pow", "**"), + ("lshift", "<<"), + ("rshift", ">>"), + ("or", "|"), + ("xor", "^"), + ("and", "&"), + ("matmul", "@"), + ]: + preserved(f""" + class T: + def __{name}__(self,other): + loc("{name}") + return other + + def __r{name}__(self,other): + loc("r{name}") + return other + + assert T() {op} 2 == 2 + assert 2 {op} T() == 2 + + assert (T + () + {op} + 2 + == + 2) + + assert (2 + {op} + T + () + == + 2) + """) + + for name, op in [ + ("eq", "=="), + ("ne", "!="), + ("lt", "<"), + ("le", "<="), + ("gt", ">"), + ("ge", ">="), + ]: + preserved(f""" + class T: + def __{name}__(self,other): + loc() + return True + + assert T() {op} 5 + assert (T + () + {op} + 5) + """) + + for name, op in [ + ("eq", "=="), + ("ne", "!="), + ("lt", ">"), + ("le", ">="), + ("gt", "<"), + ("ge", "<="), + ("contains", "in"), + ]: + preserved(f""" + class T: + def __{name}__(self,other): + loc() + return True + + assert 5 {op} T() + assert (5 + {op} + T + ()) + """) + + preserved(""" + def func(value): + loc("func") + return value + + class T: + def __iter__(self): + loc("iter") + return iter([5]) + + assert func(*T()) == 5 + """) + + preserved(""" + class T: + def __getattr__(self,name): + loc() + return name + + assert T().attr == "attr" + """) def test_dont_rewrite(self) -> None: s = """'PYTEST_DONT_REWRITE'\nassert 14"""