Skip to content

Commit 70566ef

Browse files
committed
Inital commit
1 parent 1a40830 commit 70566ef

10 files changed

+569
-0
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,6 @@ dmypy.json
127127

128128
# Pyre type checker
129129
.pyre/
130+
131+
# direnv configuration
132+
.envrc

Readme.md

Whitespace-only changes.

poetry.lock

+180
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

poetry.toml

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[virtualenvs]
2+
create = false

pyproject.toml

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
[tool.poetry]
2+
name = "pytest-diff-selector"
3+
version = "0.1.0"
4+
description = ""
5+
authors = ["Israel Fruchter <[email protected]>"]
6+
packages = [
7+
{ include = "pytest_diff_selector" },
8+
]
9+
10+
[tool.poetry.scripts]
11+
selector = 'pytest_diff_selector.main:run'
12+
13+
[tool.poetry.dependencies]
14+
python = "^3.10"
15+
unidiff = "^0.7.3"
16+
pyan3 = "^1.2.0"
17+
tqdm = "^4.62.3"
18+
rich = "^11.2.0"
19+
20+
[tool.poetry.dev-dependencies]
21+
22+
[tool.poetry.plugins.pytest11]
23+
pytest-diff-selector = "pytest_diff_selector.plugin"
24+
25+
[build-system]
26+
requires = ["poetry-core @ git+https://github.com/python-poetry/poetry-core.git@master"]
27+
build-backend = "poetry.core.masonry.api"

pytest_diff_selector/__init__.py

Whitespace-only changes.

pytest_diff_selector/main.py

+167
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
import os.path
2+
import argparse
3+
import logging
4+
import sys
5+
from subprocess import check_output
6+
from collections import defaultdict
7+
from pathlib import Path
8+
from typing import Dict
9+
10+
from unidiff import PatchSet, PatchedFile, Hunk, LINE_TYPE_ADDED, LINE_TYPE_REMOVED
11+
from pyan.analyzer import CallGraphVisitor, Flavor
12+
from tqdm import tqdm
13+
14+
15+
class CollectionVisitor(CallGraphVisitor):
16+
def __init__(self, *args, **kwargs):
17+
self.filenames_length = len(args[0])
18+
self.progress_bar = tqdm(desc="Analyzing", total=self.filenames_length * 2)
19+
super().__init__(*args, **kwargs)
20+
21+
def process(self):
22+
self.defines_edges = defaultdict(list)
23+
super().process()
24+
25+
def process_one(self, filename):
26+
if self.progress_bar and self.progress_bar.disable:
27+
print(f"scanning: {filename}")
28+
super().process_one(filename)
29+
if self.progress_bar:
30+
self.progress_bar.set_postfix_str(filename)
31+
self.progress_bar.update()
32+
33+
def postprocess(self):
34+
"""Finalize the analysis."""
35+
self.resolve_imports()
36+
if self.progress_bar:
37+
self.progress_bar.close()
38+
39+
40+
def get_diff(git_repo_directory, git_selection) -> Dict[str, set]:
41+
"""
42+
Get the unified diff from git, and return a mapping between file and
43+
line number changed
44+
"""
45+
diff = check_output(
46+
["git", "diff", "--no-prefix", git_selection], cwd=git_repo_directory, text=True
47+
)
48+
patch = PatchSet(diff)
49+
50+
print(patch, file=sys.stderr)
51+
52+
changed_lines = defaultdict(set)
53+
for f in patch:
54+
f: PatchedFile
55+
for hunk in f:
56+
hunk: Hunk
57+
removed = {
58+
l.source_line_no for l in hunk if l.line_type == LINE_TYPE_REMOVED
59+
}
60+
added = {l.target_line_no for l in hunk if l.line_type == LINE_TYPE_ADDED}
61+
changed_lines[f.path] = changed_lines[f.path].union(removed, added)
62+
63+
return changed_lines
64+
65+
66+
class AffectedTestScanner:
67+
"""
68+
scan the call graph to see which test is affected by the changes
69+
"""
70+
71+
def __init__(self, graph, changed_lines_set, root_path):
72+
self.graph = graph
73+
self.changed_lines_set = changed_lines_set
74+
self.scanned_nodes = []
75+
self.current_test = None
76+
self.test_set = set()
77+
self.root_path = root_path
78+
79+
def collect_tests(self) -> list:
80+
for key_node, nodes in self.graph.uses_edges.items():
81+
if key_node.name.startswith("test_"):
82+
self.current_test = key_node
83+
self.check_node_affected(key_node)
84+
self.scan_nodes(nodes)
85+
self.scanned_nodes.clear()
86+
87+
tests = []
88+
for test in self.test_set:
89+
relative_filename = Path(test.filename).relative_to(self.root_path)
90+
namespace = []
91+
for name in reversed(test.namespace.split(".")):
92+
if name == relative_filename.name.rstrip(".py"):
93+
break
94+
namespace += [name]
95+
namespace = "::".join(reversed(namespace))
96+
namespace = f"::{namespace}" if namespace else ""
97+
test_full_name = f"{relative_filename}{namespace}::{test.name}"
98+
tests.append(test_full_name)
99+
100+
return tests
101+
102+
def scan_nodes(self, nodes):
103+
for node in nodes:
104+
if node.flavor in [
105+
Flavor.METHOD,
106+
Flavor.CLASSMETHOD,
107+
Flavor.STATICMETHOD,
108+
Flavor.FUNCTION,
109+
]:
110+
if self.check_node_affected(node):
111+
return True # no point of continue if the test is already marked as affected
112+
if node not in self.scanned_nodes:
113+
self.scanned_nodes.append(node)
114+
if node in self.graph.uses_edges:
115+
if self.scan_nodes(self.graph.uses_edges[node]):
116+
return True # no point of continue if the test is already marked as affected
117+
elif node.flavor == Flavor.IMPORTEDITEM:
118+
if node not in self.scanned_nodes:
119+
self.scanned_nodes.append(node)
120+
if self.scan_nodes(self.graph.nodes[node.name]):
121+
return True # no point of continue if the test is already marked as affected
122+
else:
123+
continue
124+
return False
125+
126+
def check_node_affected(self, node):
127+
if node.ast_node:
128+
p = str(Path(node.filename).relative_to(self.root_path))
129+
for line in self.changed_lines_set[p]:
130+
if node.ast_node.lineno <= int(line) <= node.ast_node.end_lineno:
131+
self.test_set.add(self.current_test)
132+
return True
133+
return False
134+
135+
136+
def run():
137+
logging.basicConfig(level=logging.WARNING)
138+
139+
parser = argparse.ArgumentParser(description="Select tests based on git diff")
140+
parser.add_argument(
141+
"git_diff", default="HEAD", type=str, help="the parameter to pass to `git diff`"
142+
)
143+
parser.add_argument(
144+
"--path", dest="root_path", default=".", help="the path of the git repo to scan"
145+
)
146+
147+
args = parser.parse_args()
148+
root_path = Path(os.path.abspath(args.root_path))
149+
changed_lines_set = get_diff(root_path, args.git_diff)
150+
files_changed = list(str(root_path / f) for f in changed_lines_set.keys())
151+
files_changed = [f for f in files_changed if f.endswith(".py")]
152+
if not files_changed:
153+
print("No python file in the change/diff")
154+
sys.exit()
155+
files = list(str(f) for f in root_path.glob("**/*.py"))
156+
graph = CollectionVisitor(files, str(root_path), logger=logging)
157+
158+
scanner = AffectedTestScanner(graph, changed_lines_set, root_path)
159+
tests = scanner.collect_tests()
160+
for test in tests:
161+
print(test)
162+
163+
sys.exit()
164+
165+
166+
if __name__ == "__main__":
167+
run()

pytest_diff_selector/plugin.py

Whitespace-only changes.

requirements.txt

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
unidiff==0.7.3
2+
pyan3==1.2.0
3+
tqdm==4.62.3
4+
rich==11.2.0

0 commit comments

Comments
 (0)