diff --git a/CoverUp-arxiv-2403.16218.pdf b/CoverUp-arxiv-2403.16218.pdf
new file mode 100644
index 0000000..ceeb4b1
Binary files /dev/null and b/CoverUp-arxiv-2403.16218.pdf differ
diff --git a/README.md b/README.md
index 2eb966a..e3a7145 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,8 @@
+
+
 by [Juan Altmayer Pizzorno](https://jaltmayerpizzorno.github.io) and [Emery Berger](https://emeryberger.com)
 at UMass Amherst's [PLASMA lab](https://plasma-umass.org/).
@@ -19,28 +21,33 @@ To generate tests, it first measures your suite's coverage using [SlipCover](htt
 It then selects portions of the code that need more testing (that is, code that is uncovered).
 CoverUp then engages in a conversation with an [LLM](https://en.wikipedia.org/wiki/Large_language_model),
 prompting for tests, checking the results to verify that they run and increase coverage (again using SlipCover),
 and re-prompting for adjustments as necessary.
-Finally, CoverUp checks that the new tests integrate well, attempting to resolve any issues it finds.
+Finally, CoverUp optionally checks that the new tests integrate well, attempting to resolve any issues it finds.
 
-## Installing CoverUp
+For technical details and a complete evaluation, see our arXiv paper, [_CoverUp: Coverage-Guided LLM-Based Test Generation_](https://arxiv.org/abs/2403.16218) ([PDF](https://github.com/plasma-umass/CoverUp/blob/main/CoverUp-arxiv-2403.16218.pdf)).
 
+## Installing CoverUp
 CoverUp is available from PyPI, so you can install simply with
 ```shell
 $ python3 -m pip install coverup
 ```
 
-Currently, CoverUp requires an [OpenAI account](https://platform.openai.com/signup) to run (we plan to support local models in the near future).
-Your account will also need to have a [positive balance](https://platform.openai.com/account/usage).
-Create an [API key](https://platform.openai.com/api-keys) and store its "secret key" (usually a
+### LLM model access
+CoverUp can be used with OpenAI, Anthropic or AWS Bedrock models; it requires that the
+access details be defined as shell environment variables: `OPENAI_API_KEY`,
+`ANTHROPIC_API_KEY` or `AWS_ACCESS_KEY_ID`/`AWS_SECRET_ACCESS_KEY`/`AWS_REGION_NAME`, respectively.
+
+For example, for OpenAI you would create an [account](https://platform.openai.com/signup), ensure
+it has a [positive balance](https://platform.openai.com/account/usage) and then create
+an [API key](https://platform.openai.com/api-keys), storing its "secret key" (usually a
 string starting with `sk-`) in an environment variable named `OPENAI_API_KEY`:
 ```shell
 $ export OPENAI_API_KEY=<...your-api-key...>
 ```
 
 ## Using CoverUp
-
-If your module's source code is in `src` and your tests in `tests`, you can run CoverUp as
+If your module is named `mymod`, its sources are under `src`, and its tests are under `tests`, you can run CoverUp as
 ```shell
-$ coverup --source-dir src --tests-dir tests
+$ coverup --source-dir src/mymod --tests-dir tests
 ```
 
 CoverUp then creates tests named `test_coverup_N.py`, where `N` is a number, under the `tests` directory.
@@ -48,7 +55,7 @@ CoverUp then creates tests named `test_coverup_N.py`, where `N` is a number, und
 Here we have CoverUp create additional tests for the popular package [Flask](https://flask.palletsprojects.com/):
 ```
-$ coverup --source-dir src/flask --tests-dir tests
+$ coverup --source-dir src/flask --tests-dir tests --disable-polluting --no-isolate-tests
 Measuring test suite coverage...
 starting coverage: 90.2%
 Prompting gpt-4-1106-preview for tests to increase coverage...
100%|███████████████████████████████████████████████████| 95/95 [02:49<00:00, 1.79s/it, usage=~$3.30, G=51, F=141, U=22, R=0] diff --git a/images/comparison.png b/images/comparison.png index 7d11056..4e8d12d 100644 Binary files a/images/comparison.png and b/images/comparison.png differ diff --git a/images/logo-with-title.png b/images/logo-with-title.png new file mode 100644 index 0000000..f286bb7 Binary files /dev/null and b/images/logo-with-title.png differ diff --git a/pyproject.toml b/pyproject.toml index a2d3458..f479576 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,8 +22,8 @@ dependencies = [ "tiktoken", "aiolimiter", "tqdm", - "llm_utils", - "slipcover>=1.0.3", + "slipcover>=1.0.8", + "pytest-forked", "litellm>=1.33.1" ] diff --git a/src/.gitignore b/src/.gitignore new file mode 100644 index 0000000..27a8f48 --- /dev/null +++ b/src/.gitignore @@ -0,0 +1 @@ +CoverUp.egg-info diff --git a/src/coverup/coverup.py b/src/coverup/coverup.py index 90097fb..4a7fafd 100644 --- a/src/coverup/coverup.py +++ b/src/coverup/coverup.py @@ -1,6 +1,12 @@ import asyncio import json -import litellm # type: ignore +import warnings + +with warnings.catch_warnings(): + # ignore pydantic warnings https://github.com/BerriAI/litellm/issues/2832 + warnings.simplefilter('ignore') + import litellm # type: ignore + import logging import openai import subprocess @@ -14,10 +20,9 @@ from .llm import * from .segment import * from .testrunner import * +from . import prompt -PREFIX = 'coverup' - # Turn off most logging litellm.set_verbose = False logging.getLogger().setLevel(logging.ERROR) @@ -33,7 +38,7 @@ def parse_args(args=None): help='only process certain source file(s)') def Path_dir(value): - path_dir = Path(value) + path_dir = Path(value).resolve() if not path_dir.is_dir(): raise argparse.ArgumentTypeError("must be a directory") return path_dir @@ -48,10 +53,23 @@ def Path_dir(value): ap.add_argument('--no-checkpoint', action='store_const', const=None, dest='checkpoint', default=argparse.SUPPRESS, help=f'disables checkpoint') - ap.add_argument('--model', type=str, + def default_model(): + if 'OPENAI_API_KEY' in os.environ: + return "openai/gpt-4-1106-preview" + if 'ANTHROPIC_API_KEY' in os.environ: + return "anthropic/claude-3-sonnet-20240229" + if 'AWS_ACCESS_KEY_ID' in os.environ: + return "bedrock/anthropic.claude-3-sonnet-20240229-v1:0" + + ap.add_argument('--model', type=str, default=default_model(), help='OpenAI model to use') - ap.add_argument('--model-temperature', type=str, default=0, + ap.add_argument('--prompt-family', type=str, + choices=list(prompt.prompters.keys()), + default='gpt', + help='Prompt style to use') + + ap.add_argument('--model-temperature', type=float, default=0, help='Model "temperature" to use') ap.add_argument('--line-limit', type=int, default=50, @@ -74,7 +92,7 @@ def Path_dir(value): action=argparse.BooleanOptionalAction, help=f'show details of lines/branches after each response') - ap.add_argument('--log-file', default=f"{PREFIX}-log", + ap.add_argument('--log-file', default=f"coverup-log", help='log file to use') ap.add_argument('--pytest-args', type=str, default='', @@ -84,20 +102,40 @@ def Path_dir(value): action=argparse.BooleanOptionalAction, help='attempt to install any missing modules') + ap.add_argument('--prefix', type=str, default='coverup', + help='prefix to use for test file names') + ap.add_argument('--write-requirements-to', type=Path, help='append the name of any missing modules to the given file') - ap.add_argument('--failing-test-action', 
choices=['disable', 'find-culprit'], default='disable', - help='what to do about failing tests when checking the entire suite.') + ap.add_argument('--disable-polluting', default=False, + action=argparse.BooleanOptionalAction, + help='look for tests causing others to fail and disable them') - ap.add_argument('--only-disable-interfering-tests', default=False, + ap.add_argument('--disable-failing', default=False, action=argparse.BooleanOptionalAction, - help='rather than try to add new tests, only look for tests causing others to fail and disable them.') + help='look for failing tests and disable them') + + ap.add_argument('--prompt-for-tests', default=True, + action=argparse.BooleanOptionalAction, + help='prompt LLM for new tests') + + ap.add_argument('--isolate-tests', default=True, + action=argparse.BooleanOptionalAction, + help='run tests in isolation (to work around any state pollution) when measuring suite coverage') ap.add_argument('--debug', '-d', default=False, action=argparse.BooleanOptionalAction, help='print out debugging messages.') + ap.add_argument('--add-to-pythonpath', default=True, + action=argparse.BooleanOptionalAction, + help='add (parent of) source directory to PYTHONPATH') + + ap.add_argument('--branch-coverage', default=True, + action=argparse.BooleanOptionalAction, + help=argparse.SUPPRESS) + def positive_int(value): ivalue = int(value) if ivalue < 0: raise argparse.ArgumentTypeError("must be a number >= 0") @@ -106,13 +144,21 @@ def positive_int(value): ap.add_argument('--max-concurrency', type=positive_int, default=50, help='maximum number of parallel requests; 0 means unlimited') - return ap.parse_args(args) + args = ap.parse_args(args) + + for i in range(len(args.source_files)): + args.source_files[i] = args.source_files[i].resolve() + + if args.disable_failing and args.disable_polluting: + ap.error('Specify only one of --disable-failing and --disable-polluting') + + return args def test_file_path(test_seq: int) -> Path: """Returns the Path for a test's file, given its sequence number.""" global args - return args.tests_dir / f"test_{PREFIX}_{test_seq}.py" + return args.tests_dir / f"test_{args.prefix}_{test_seq}.py" test_seq = 1 @@ -163,44 +209,59 @@ def log_write(seg: CodeSegment, m: str) -> None: log_file.write(f"---- {datetime.now().isoformat(timespec='seconds')} {seg} ----\n{m}\n") -def disable_interfering_tests() -> dict: - """While the test suite fails, disables any interfering tests. - If the test suite succeeds, returns the coverage observed.""" +def check_whole_suite() -> None: + """Check whole suite and disable any polluting/failing tests.""" + + pytest_args = args.pytest_args + if args.disable_polluting: + pytest_args += " -x" # stop at first (to save time) while True: - print("Checking test suite... ", end='') + print("Checking test suite... 
", end='', flush=True) try: - coverage = measure_suite_coverage(tests_dir=args.tests_dir, source_dir=args.source_dir, - pytest_args=args.pytest_args, - trace=(print if args.debug else None)) - print("tests ok!") - return coverage + btf = BadTestsFinder(tests_dir=args.tests_dir, pytest_args=pytest_args, + branch_coverage=args.branch_coverage, + trace=(print if args.debug else None)) + outcomes = btf.run_tests() + failing_tests = list(p for p, o in outcomes.items() if o == 'failed') + if not failing_tests: + print("tests ok!") + return except subprocess.CalledProcessError as e: - failing_test = parse_failed_tests(args.tests_dir, e)[0] + print(str(e) + "\n" + str(e.stdout, 'UTF-8', errors='ignore')) + sys.exit(1) - btf = BadTestsFinder(tests_dir=args.tests_dir, pytest_args=args.pytest_args, - trace=(print if args.debug else None)) - - if args.failing_test_action == 'disable': - # just disable failing test(s) while we work on BTF - print(f"failed ({failing_test}). Looking for failing tests(s) to disable...") - culprits = btf.run_tests() + if args.disable_failing: + print(f"{len(failing_tests)} test(s) failed, disabling...") + to_disable = failing_tests else: - print(f"{failing_test} is failing, looking for culprit(s)...") - if btf.run_tests({failing_test}) == {failing_test}: - print(f"{failing_test} fails by itself(!)") - culprits = {failing_test} - else: - culprits = btf.find_culprit(failing_test) + print(f"{failing_tests[0]} failed; Looking for culprit(s)...") + + def print_noeol(message): + # ESC[K clears the rest of the line + print(message, end='...\033[K\r', flush=True) + + try: + btf = BadTestsFinder(tests_dir=args.tests_dir, pytest_args=args.pytest_args, + branch_coverage=args.branch_coverage, + trace=(print if args.debug else None), + progress=(print if args.debug else print_noeol)) + + to_disable = btf.find_culprit(failing_tests[0]) - for c in culprits: - print(f"Disabling {c}") - c.rename(c.parent / ("disabled_" + c.name)) + except BadTestFinderError as e: + print(e) + to_disable = {failing_tests[0]} + + for t in to_disable: + print(f"Disabling {t}") + t.rename(t.parent / ("disabled_" + t.name)) def find_imports(python_code: str) -> T.List[str]: + """Collects a list of packages needed by a program by examining its 'import' statements""" import ast try: @@ -217,7 +278,8 @@ def find_imports(python_code: str) -> T.List[str]: modules.append(name.name.split('.')[0]) elif isinstance(n, ast.ImportFrom): - modules.append(n.module.split('.')[0]) + if n.module and n.level == 0: + modules.append(n.module.split('.')[0]) return modules @@ -245,7 +307,7 @@ def install_missing_imports(seg: CodeSegment, modules: T.List[str]) -> bool: print(f"Installed module {module}") log_write(seg, f"Installed module {module}") except subprocess.CalledProcessError as e: - log_write(seg, f"Unable to install module {module}:\n{e.stdout}") + log_write(seg, f"Unable to install module {module}:\n{str(e.stdout, 'UTF-8', errors='ignore')}") all_ok = False return all_ok @@ -256,16 +318,6 @@ def get_required_modules() -> T.List[str]: return [m for m in module_available if module_available[m] != 1] -def get_module_name(src_file: Path, src_dir: Path) -> str: - try: - src_file = Path(src_file) - src_dir = Path(src_dir) - relative = src_file.resolve().relative_to(src_dir.resolve()) - return ".".join((src_dir.stem,) + relative.parts[:-1] + (relative.stem,)) - except ValueError: - return None # not relative to source - - PROGRESS_COUNTERS=['G', 'F', 'U', 'R'] # good, failed, useless, retry class Progress: """Tracks progress, 
showing a tqdm-based bar.""" @@ -357,12 +409,12 @@ def inc_counter(self, key: str): def mark_done(self, seg: CodeSegment): """Marks a segment done.""" - self.done[seg.filename].add((seg.begin, seg.end)) + self.done[seg.path].add((seg.begin, seg.end)) def is_done(self, seg: CodeSegment): """Returns whether a segment is done.""" - return (seg.begin, seg.end) in self.done[seg.filename] + return (seg.begin, seg.end) in self.done[seg.path] @staticmethod @@ -377,7 +429,7 @@ def load_checkpoint(ckpt_file: Path): # -> State state = State(ckpt['coverage']) for filename, done_list in ckpt['done'].items(): - state.done[filename] = set(tuple(d) for d in done_list) + state.done[Path(filename).resolve()] = set(tuple(d) for d in done_list) state.add_usage(ckpt['usage']) if 'counters' in ckpt: state.counters = ckpt['counters'] @@ -391,7 +443,7 @@ def save_checkpoint(self, ckpt_file: Path): """Saves this state to a checkpoint file.""" ckpt = { 'version': 1, - 'done': {k:list(v) for k,v in self.done.items() if len(v)}, # cannot serialize 'set' as-is + 'done': {str(k):list(v) for k,v in self.done.items() if len(v)}, # cannot serialize 'Path' or 'set' as-is 'usage': self.usage, 'counters': self.counters, 'coverage': self.coverage @@ -424,7 +476,9 @@ async def do_chat(seg: CodeSegment, completion: dict) -> str: return await litellm.acreate(**completion) - except (openai.RateLimitError, openai.APITimeoutError) as e: + except (litellm.exceptions.ServiceUnavailableError, + openai.RateLimitError, + openai.APITimeoutError) as e: # This message usually indicates out of money in account if 'You exceeded your current quota' in str(e): @@ -444,11 +498,17 @@ async def do_chat(seg: CodeSegment, completion: dict) -> str: log_write(seg, f"Error: {type(e)} {e}") return None # gives up this segment - except (ConnectionError) as e: + except openai.APIConnectionError as e: log_write(seg, f"Error: {type(e)} {e}") # usually a server-side error... just retry right away state.inc_counter('R') + except openai.APIError as e: + # APIError is the base class for all API errors; + # we may be missing a more specific handler. + print(f"Error: {type(e)} {e}; missing handler?") + log_write(seg, f"Error: {type(e)} {e}") + return None # gives up this segment def extract_python(response: str) -> str: # This regex accepts a truncated code block... this seems fine since we'll try it anyway @@ -461,35 +521,13 @@ async def improve_coverage(seg: CodeSegment) -> bool: """Works to improve coverage for a code segment.""" global args, progress - def pl(item, singular, plural = None): - if len(item) <= 1: - return singular - return plural if plural is not None else f"{singular}s" - - module_name = get_module_name(seg.filename, args.source_dir) - - messages = [{"role": "user", - "content": f""" -You are an expert Python test-driven developer. -The code below, extracted from {seg.filename},{' module ' + module_name + ',' if module_name else ''} does not achieve full coverage: -when tested, {seg.lines_branches_missing_do()} not execute. -Create a new pytest test function that executes these missing lines/branches, always making -sure that the new test is correct and indeed improves coverage. -Always send entire Python test scripts when proposing a new test or correcting one you -previously proposed. -Be sure to include assertions in the test that verify any applicable postconditions. -Please also make VERY SURE to clean up after the test, so as not to affect other tests; -use 'pytest-mock' if appropriate. 
-Write as little top-level code as possible, and in particular do not include any top-level code -calling into pytest.main or the test itself. -Respond ONLY with the Python code enclosed in backticks, without any explanation. -```python -{seg.get_excerpt()} -``` -""" - }] - - log_write(seg, messages[0]['content']) # initial prompt + def log_prompts(prompts: T.List[dict]): + for p in prompts: + log_write(seg, p['content']) + + prompter = prompt.prompters[args.prompt_family](args=args, segment=seg) + messages = prompter.initial_prompt() + log_prompts(messages) attempts = 0 @@ -502,8 +540,14 @@ def pl(item, singular, plural = None): log_write(seg, "Too many attempts, giving up") break - if not (response := await do_chat(seg, {'model': args.model, 'messages': messages, - 'temperature': args.model_temperature})): + completion = {'model': args.model, + 'messages': messages, + 'temperature': args.model_temperature} + + if "ollama" in args.model: + completion["api_base"] = "http://localhost:11434" + + if not (response := await do_chat(seg, completion)): log_write(seg, "giving up") break @@ -528,6 +572,7 @@ def pl(item, singular, plural = None): try: result = await measure_test_coverage(test=last_test, tests_dir=args.tests_dir, pytest_args=args.pytest_args, + branch_coverage=args.branch_coverage, log_write=lambda msg: log_write(seg, msg)) except subprocess.TimeoutExpired: @@ -538,18 +583,14 @@ def pl(item, singular, plural = None): except subprocess.CalledProcessError as e: state.inc_counter('F') - messages.append({ - "role": "user", - "content": "Executing the test yields an error, shown below.\n" +\ - "Modify the test to correct it; respond only with the complete Python code in backticks.\n\n" +\ - clean_error(str(e.stdout, 'UTF-8', errors='ignore')) - }) - log_write(seg, messages[-1]['content']) + prompts = prompter.error_prompt(clean_error(str(e.stdout, 'UTF-8', errors='ignore'))) + messages.extend(prompts) + log_prompts(prompts) continue new_lines = set(result[seg.filename]['executed_lines']) if seg.filename in result else set() new_branches = set(tuple(b) for b in result[seg.filename]['executed_branches']) \ - if seg.filename in result else set() + if (seg.filename in result and 'executed_branches' in result[seg.filename]) else set() now_missing_lines = seg.missing_lines - new_lines now_missing_branches = seg.missing_branches - new_branches @@ -562,15 +603,10 @@ def pl(item, singular, plural = None): # XXX insist on len(now_missing_lines)+len(now_missing_branches) == 0 ? if len(now_missing_lines)+len(now_missing_branches) == seg.missing_count(): - messages.append({ - "role": "user", - "content": f""" -This test still lacks coverage: {lines_branches_do(now_missing_lines, set(), now_missing_branches)} not execute. -Modify it to correct that; respond only with the complete Python code in backticks. -""" - }) - log_write(seg, messages[-1]['content']) state.inc_counter('U') + prompts = prompter.missing_coverage_prompt(now_missing_lines, now_missing_branches) + messages.extend(prompts) + log_prompts(prompts) continue # the test is good 'nough... 
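The retry loop above gets all of its messages from `prompt.prompters[args.prompt_family]`, so a new prompt style only needs to implement the `Prompter` interface and register itself. A minimal sketch of what such an addition could look like (the `TersePrompter` name and `terse` key are hypothetical, not part of this patch):

```python
# Hypothetical extra prompt family; assumes it ends up importable alongside
# the other prompters in coverup.prompt.
import typing as T
from coverup.prompt import Prompter, _message, prompters

class TersePrompter(Prompter):
    """Bare-bones prompts; reuses the same retry loop as the other prompt families."""

    def initial_prompt(self) -> T.List[dict]:
        seg = self.segment
        return [_message(
            f"The code below, from {seg.filename}, lacks coverage: "
            f"{seg.lines_branches_missing_do()} not execute.\n"
            "Write pytest tests that execute them; respond only with Python code in backticks.\n"
            f"{seg.get_excerpt()}")]

    def error_prompt(self, error: str) -> T.List[dict]:
        return [_message("Executing the test produced the error below; fix the test and "
                         "respond only with the complete Python code in backticks.\n" + error)]

    def missing_coverage_prompt(self, now_missing_lines: set,
                                now_missing_branches: set) -> T.List[dict]:
        return [_message("The test runs but still misses the targeted lines/branches; "
                         "modify it and respond only with the complete Python code in backticks.")]

# once part of coverup.prompt, this would appear as a --prompt-family choice
prompters["terse"] = TersePrompter
```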
@@ -595,7 +631,6 @@ def add_to_pythonpath(source_dir: Path): def main(): - from collections import defaultdict import os @@ -607,154 +642,174 @@ def main(): return 1 # add source dir to paths so that the module doesn't need to be installed to be worked on - add_to_pythonpath(args.source_dir) - - if args.only_disable_interfering_tests: - disable_interfering_tests() - return - - if args.rate_limit or token_rate_limit_for_model(args.model): - limit = (args.rate_limit, 60) if args.rate_limit else token_rate_limit_for_model(args.model) - from aiolimiter import AsyncLimiter - token_rate_limit = AsyncLimiter(*limit) - # TODO also add request limit, and use 'await asyncio.gather(t.acquire(tokens), r.acquire())' to acquire both - - - # Check for an API key for OpenAI or Amazon Bedrock. - if 'OPENAI_API_KEY' not in os.environ: - if not all(x in os.environ for x in ['AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY', 'AWS_REGION_NAME']): - print("You need a key (or keys) from an AI service to use CoverUp.") - print() - print("OpenAI:") - print(" You can get a key here: https://platform.openai.com/api-keys") - print(" Set the environment variable OPENAI_API_KEY to your key value:") - print(" export OPENAI_API_KEY=") - print() - print() - print("Bedrock:") - print(" To use Bedrock, you need an AWS account.") - print(" Set the following environment variables:") - print(" export AWS_ACCESS_KEY_ID=") - print(" export AWS_SECRET_ACCESS_KEY=") - print(" export AWS_REGION_NAME=us-west-2") - print(" You also need to request access to Claude:") - print( - " https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html#manage-model-access" - ) - print() - return 1 + if args.add_to_pythonpath: + add_to_pythonpath(args.source_dir) + + if args.prompt_for_tests: + if args.rate_limit or token_rate_limit_for_model(args.model): + limit = (args.rate_limit, 60) if args.rate_limit else token_rate_limit_for_model(args.model) + from aiolimiter import AsyncLimiter + token_rate_limit = AsyncLimiter(*limit) + # TODO also add request limit, and use 'await asyncio.gather(t.acquire(tokens), r.acquire())' to acquire both + + + # Check for an API key for OpenAI or Amazon Bedrock. 
+ if 'OPENAI_API_KEY' not in os.environ and 'ANTHROPIC_API_KEY' not in os.environ: + if not all(x in os.environ for x in ['AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY', 'AWS_REGION_NAME']): + print("You need a key (or keys) from an AI service to use CoverUp.") + print() + print("OpenAI:") + print(" You can get a key here: https://platform.openai.com/api-keys") + print(" Set the environment variable OPENAI_API_KEY to your key value:") + print(" export OPENAI_API_KEY=") + print() + print() + print("Anthropic:") + print(" Set the environment variable ANTHROPIC_API_KEY to your key value:") + print(" export ANTHROPIC_API_KEY=") + print() + print() + print("Bedrock:") + print(" To use Bedrock, you need an AWS account.") + print(" Set the following environment variables:") + print(" export AWS_ACCESS_KEY_ID=") + print(" export AWS_SECRET_ACCESS_KEY=") + print(" export AWS_REGION_NAME=us-west-2") + print(" You also need to request access to Claude:") + print( + " https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html#manage-model-access" + ) + print() + return 1 - if 'OPENAI_API_KEY' in os.environ: - if not args.model: - # args.model = "openai/gpt-4" - args.model = "openai/gpt-4-1106-preview" - # openai.key=os.environ['OPENAI_API_KEY'] - #if 'OPENAI_ORGANIZATION' in os.environ: - # openai.organization=os.environ['OPENAI_ORGANIZATION'] - else: - # args.model = "bedrock/anthropic.claude-v2:1" if not args.model: - args.model = "bedrock/anthropic.claude-3-sonnet-20240229-v1:0" - log_write('startup', f"Command: {' '.join(sys.argv)}") + print("Please specify model to use with --model") + return 1 - # --- (1) load or measure initial coverage, figure out segmentation --- + log_write('startup', f"Command: {' '.join(sys.argv)}") - if args.checkpoint and (state := State.load_checkpoint(args.checkpoint)): - print("Continuing from checkpoint; ", end='') - else: - try: - coverage = disable_interfering_tests() + # --- (1) load or measure initial coverage, figure out segmentation --- - except subprocess.CalledProcessError as e: - print("Error measuring coverage:\n" + str(e.stdout, 'UTF-8', errors='ignore')) - return 1 + if args.checkpoint and (state := State.load_checkpoint(args.checkpoint)): + print("Continuing from checkpoint; coverage: ", end='', flush=True) + coverage = state.get_initial_coverage() + else: + if args.disable_polluting or args.disable_failing: + # check and clean up suite before measuring coverage + check_whole_suite() - state = State(coverage) + try: + print("Measuring coverage... 
", end='', flush=True) + coverage = measure_suite_coverage(tests_dir=args.tests_dir, source_dir=args.source_dir, + pytest_args=args.pytest_args, + isolate_tests=args.isolate_tests, + branch_coverage=args.branch_coverage, + trace=(print if args.debug else None)) + state = State(coverage) - coverage = state.get_initial_coverage() + except subprocess.CalledProcessError as e: + print("Error measuring coverage:\n" + str(e.stdout, 'UTF-8', errors='ignore')) + return 1 - print(f"Initial coverage: {coverage['summary']['percent_covered']:.1f}%") - # TODO also show running coverage estimate + print(f"{coverage['summary']['percent_covered']:.1f}%") + # TODO also show running coverage estimate - segments = sorted(get_missing_coverage(state.get_initial_coverage(), line_limit=args.line_limit), - key=lambda seg: seg.missing_count(), reverse=True) + segments = sorted(get_missing_coverage(state.get_initial_coverage(), line_limit=args.line_limit), + key=lambda seg: seg.missing_count(), reverse=True) - # save initial coverage so we don't have to redo it next time - if args.checkpoint: - state.save_checkpoint(args.checkpoint) + # save initial coverage so we don't have to redo it next time + if args.checkpoint: + state.save_checkpoint(args.checkpoint) - # --- (2) prompt for tests --- + # --- (2) prompt for tests --- - print(f"Prompting {args.model} for tests to increase coverage...") - print("(in the following, G=good, F=failed, U=useless and R=retry)") + print(f"Prompting {args.model} for tests to increase coverage...") + print("(in the following, G=good, F=failed, U=useless and R=retry)") - async def work_segment(seg: CodeSegment) -> None: - if await improve_coverage(seg): - # Only mark done if was able to complete (True return), - # so that it can be retried after installing any missing modules - state.mark_done(seg) + async def work_segment(seg: CodeSegment) -> None: + if await improve_coverage(seg): + # Only mark done if was able to complete (True return), + # so that it can be retried after installing any missing modules + state.mark_done(seg) - if args.checkpoint: - state.save_checkpoint(args.checkpoint) - progress.signal_one_completed() + if args.checkpoint: + state.save_checkpoint(args.checkpoint) + progress.signal_one_completed() - worklist = [] - seg_done_count = 0 - for seg in segments: - if not args.source_dir in Path(seg.filename).parents: - continue + worklist = [] + seg_done_count = 0 + for seg in segments: + if not seg.path.is_relative_to(args.source_dir): + continue - if args.source_files and all(seg.filename not in str(s) for s in args.source_files): - continue + if args.source_files and seg.path not in args.source_files: + continue - if state.is_done(seg): - seg_done_count += 1 - else: - worklist.append(work_segment(seg)) + if state.is_done(seg): + seg_done_count += 1 + else: + worklist.append(work_segment(seg)) - progress = Progress(total=len(worklist)+seg_done_count, initial=seg_done_count) - state.set_progress_bar(progress) + progress = Progress(total=len(worklist)+seg_done_count, initial=seg_done_count) + state.set_progress_bar(progress) - async def run_it(): - if args.max_concurrency: - semaphore = asyncio.Semaphore(args.max_concurrency) + async def run_it(): + if args.max_concurrency: + semaphore = asyncio.Semaphore(args.max_concurrency) - async def sem_coro(coro): - async with semaphore: - return await coro + async def sem_coro(coro): + async with semaphore: + return await coro - await asyncio.gather(*(sem_coro(c) for c in worklist)) - else: - await asyncio.gather(*worklist) + await 
asyncio.gather(*(sem_coro(c) for c in worklist)) + else: + await asyncio.gather(*worklist) - try: - asyncio.run(run_it()) - except KeyboardInterrupt: - print("Interrupted.") - if args.checkpoint: - state.save_checkpoint(args.checkpoint) - return 1 + try: + asyncio.run(run_it()) + except KeyboardInterrupt: + print("Interrupted.") + if args.checkpoint: + state.save_checkpoint(args.checkpoint) + return 1 + + progress.close() + + # --- (3) clean up resulting test suite --- + + if args.disable_polluting or args.disable_failing: + check_whole_suite() + + # --- (4) show final coverage - progress.close() + if args.prompt_for_tests: + try: + print("Measuring coverage... ", end='', flush=True) + coverage = measure_suite_coverage(tests_dir=args.tests_dir, source_dir=args.source_dir, + pytest_args=args.pytest_args, + isolate_tests=args.isolate_tests, + branch_coverage=args.branch_coverage, + trace=(print if args.debug else None)) - # --- (3) check resulting test suite --- + except subprocess.CalledProcessError as e: + print("Error measuring coverage:\n" + str(e.stdout, 'UTF-8', errors='ignore')) + return 1 - coverage = disable_interfering_tests() - print(f"End coverage: {coverage['summary']['percent_covered']:.1f}%") + print(f"{coverage['summary']['percent_covered']:.1f}%") - if args.checkpoint: - state.set_final_coverage(coverage) - state.save_checkpoint(args.checkpoint) + # --- (5) save state and show missing modules, if appropriate - # --- (4) final remarks --- + if args.checkpoint: + state.set_final_coverage(coverage) + state.save_checkpoint(args.checkpoint) - if required := get_required_modules(): - # Sometimes GPT outputs 'from your_module import XYZ', asking us to modify - # FIXME move this to 'state' - print(f"Some modules seem missing: {', '.join(str(m) for m in required)}") - if args.write_requirements_to: - with args.write_requirements_to.open("a") as f: - for module in required: - f.write(f"{module}\n") + if required := get_required_modules(): + # Sometimes GPT outputs 'from your_module import XYZ', asking us to modify + # FIXME move this to 'state' + print(f"Some modules seem missing: {', '.join(str(m) for m in required)}") + if args.write_requirements_to: + with args.write_requirements_to.open("a") as f: + for module in required: + f.write(f"{module}\n") return 0 diff --git a/src/coverup/delta.py b/src/coverup/delta.py index d08692a..6902a66 100644 --- a/src/coverup/delta.py +++ b/src/coverup/delta.py @@ -29,7 +29,7 @@ def get_ranges(): yield str(a) if a == b else f"{a}-{b}" a = n - return ", ".join(list(get_ranges()) + sorted(names)) + return "{" + ", ".join(list(get_ranges()) + sorted(names)) + "}" class DeltaDebugger(abc.ABC): @@ -54,10 +54,12 @@ def debug(self, changes: set, rest: set = set(), **kwargs) -> set: change_list = sorted(changes) c1 = set(change_list[:len_changes//2]) + if self.trace: self.trace(f"checking {_compact(change_list[:len_changes//2])}") if self.test(c1.union(rest), **kwargs): return self.debug(c1, rest, **kwargs) # in 1st half c2 = set(change_list[len_changes//2:]) + if self.trace: self.trace(f"checking {_compact(change_list[len_changes//2:])}") if self.test(c2.union(rest), **kwargs): return self.debug(c2, rest, **kwargs) # in 2nd half diff --git a/src/coverup/llm.py b/src/coverup/llm.py index a233969..749316f 100644 --- a/src/coverup/llm.py +++ b/src/coverup/llm.py @@ -1,5 +1,5 @@ import typing as T -import llm_utils +import litellm # Tier 5 rate limits for models; tuples indicate limit and interval in seconds @@ -45,32 +45,40 @@ def 
token_rate_limit_for_model(model_name: str) -> T.Tuple[int, int]:
+    if model_name.startswith('openai/'):
+        model_name = model_name[7:]
+
     if (model_limits := MODEL_RATE_LIMITS.get(model_name)):
         return model_limits.get('token')
 
     return None
 
 
-def compute_cost(usage: dict, model: str) -> float:
+def compute_cost(usage: dict, model_name: str) -> float:
     from math import ceil
 
-    if 'prompt_tokens' in usage and 'completion_tokens' in usage:
-        try:
-            return llm_utils.calculate_cost(usage['prompt_tokens'], usage['completion_tokens'], model)
+    if model_name.startswith('openai/'):
+        model_name = model_name[7:]
 
-        except ValueError:
-            pass # unknown model
+    if 'prompt_tokens' in usage and 'completion_tokens' in usage:
+        if (cost := litellm.model_cost.get(model_name)):
+            return usage['prompt_tokens'] * cost['input_cost_per_token'] +\
+                   usage['completion_tokens'] * cost['output_cost_per_token']
 
     return None
 
 
 _token_encoding_cache = dict()
-def count_tokens(model: str, completion: dict):
+def count_tokens(model_name: str, completion: dict):
     """Counts the number of tokens in a chat completion request."""
     import tiktoken
 
-    if not (encoding := _token_encoding_cache.get(model)):
-        encoding = _token_encoding_cache[model] = tiktoken.encoding_for_model(model)
+    if not (encoding := _token_encoding_cache.get(model_name)):
+        model = model_name
+        if model_name.startswith('openai/'):
+            model = model_name[7:]
+
+        encoding = _token_encoding_cache[model_name] = tiktoken.encoding_for_model(model)
 
     count = 0
     for m in completion['messages']:
diff --git a/src/coverup/prompt.py b/src/coverup/prompt.py
new file mode 100644
index 0000000..b7bb0d4
--- /dev/null
+++ b/src/coverup/prompt.py
@@ -0,0 +1,223 @@
+import abc
+from pathlib import Path
+from .utils import lines_branches_do
+from .segment import CodeSegment
+import typing as T
+
+
+def get_module_name(src_file: Path, src_dir: Path) -> str:
+    # assumes both src_file and src_dir Path.resolve()'d
+    try:
+        relative = src_file.relative_to(src_dir)
+        return ".".join((src_dir.stem,) + relative.parts[:-1] + (relative.stem,))
+    except ValueError:
+        return None # not relative to source
+
+
+class Prompter(abc.ABC):
+    """Interface for a CoverUp prompter."""
+
+    def __init__(self, args, segment: CodeSegment):
+        self.args = args
+        self.segment = segment
+
+
+    @abc.abstractmethod
+    def initial_prompt(self) -> T.List[dict]:
+        """Returns initial prompt(s) for a code segment."""
+
+
+    @abc.abstractmethod
+    def error_prompt(self, error: str) -> T.List[dict]:
+        """Returns prompt(s) in response to an error."""
+
+
+    @abc.abstractmethod
+    def missing_coverage_prompt(self, now_missing_lines: set, now_missing_branches: set) -> T.List[dict]:
+        """Returns prompt(s) in response to the suggested test lacking coverage."""
+
+
+def _message(content: str, *, role="user") -> dict:
+    return {
+        'role': role,
+        'content': content
+    }
+
+
+class Gpt4PrompterV1(Prompter):
+    """Prompter for GPT-4 used in paper submission."""
+
+    def __init__(self, *args, **kwargs):
+        Prompter.__init__(self, *args, **kwargs)
+
+
+    def initial_prompt(self) -> T.List[dict]:
+        args = self.args
+        seg = self.segment
+        module_name = get_module_name(seg.path, args.source_dir)
+
+        return [
+            _message(f"""
+You are an expert Python test-driven developer.
+The code below, extracted from {seg.filename},{' module ' + module_name + ',' if module_name else ''} does not achieve full coverage:
+when tested, {seg.lines_branches_missing_do()} not execute.
+Create a new pytest test function that executes these missing lines/branches, always making +sure that the new test is correct and indeed improves coverage. +Always send entire Python test scripts when proposing a new test or correcting one you +previously proposed. +Be sure to include assertions in the test that verify any applicable postconditions. +Please also make VERY SURE to clean up after the test, so as not to affect other tests; +use 'pytest-mock' if appropriate. +Write as little top-level code as possible, and in particular do not include any top-level code +calling into pytest.main or the test itself. +Respond ONLY with the Python code enclosed in backticks, without any explanation. +```python +{seg.get_excerpt()} +``` +""") + ] + + def error_prompt(self, error: str) -> T.List[dict]: + return [_message(f"""\ +Executing the test yields an error, shown below. +Modify the test to correct it; respond only with the complete Python code in backticks. + +{error}""") + ] + + + def missing_coverage_prompt(self, now_missing_lines: set, now_missing_branches: set) -> T.List[dict]: + return [_message(f"""\ +This test still lacks coverage: {lines_branches_do(now_missing_lines, set(), now_missing_branches)} not execute. +Modify it to correct that; respond only with the complete Python code in backticks. +""") + ] + + +class Gpt4Prompter(Prompter): + """Prompter for GPT-4.""" + + def __init__(self, *args, **kwargs): + Prompter.__init__(self, *args, **kwargs) + + + def initial_prompt(self) -> T.List[dict]: + args = self.args + seg = self.segment + module_name = get_module_name(seg.path, args.source_dir) + filename = seg.path.relative_to(args.source_dir.parent) + + return [ + _message(f""" +You are an expert Python test-driven developer. +The code below, extracted from {filename}, does not achieve full coverage: +when tested, {seg.lines_branches_missing_do()} not execute. +Create new pytest test functions that execute these missing lines/branches, always making +sure that the tests are correct and indeed improve coverage. +Always send entire Python test scripts when proposing a new test or correcting one you +previously proposed. +Be sure to include assertions in the test that verify any applicable postconditions. +Please also make VERY SURE to clean up after the test, so as to avoid state pollution; +use 'monkeypatch' or 'pytest-mock' if appropriate. +Write as little top-level code as possible, and in particular do not include any top-level code +calling into pytest.main or the test itself. +Respond ONLY with the Python code enclosed in backticks, without any explanation. +```python +{seg.get_excerpt()} +``` +""") + ] + + + def error_prompt(self, error: str) -> T.List[dict]: + return [_message(f"""\ +Executing the test yields an error, shown below. +Modify the test to correct it; respond only with the complete Python code in backticks. + +{error}""") + ] + + + def missing_coverage_prompt(self, now_missing_lines: set, now_missing_branches: set) -> T.List[dict]: + return [_message(f"""\ +This test still lacks coverage: {lines_branches_do(now_missing_lines, set(), now_missing_branches)} not execute. +Modify it to correct that; respond only with the complete Python code in backticks. 
+""") + ] + +class ClaudePrompter(Prompter): + """Prompter for Claude.""" + + def __init__(self, *args, **kwargs): + Prompter.__init__(self, *args, **kwargs) + + + def initial_prompt(self) -> T.List[str]: + args = self.args + seg = self.segment + module_name = get_module_name(seg.path, args.source_dir) + + return [ + _message("You are an expert Python test-driven developer who creates pytest test functions that achieve high coverage.", + role="system"), + _message(f""" + +{seg.get_excerpt()} + + + + +The code above does not achieve full coverage: +when tested, {seg.lines_branches_missing_do()} not execute. + +1. Create a new pytest test function that executes these missing lines/branches, always making +sure that the new test is correct and indeed improves coverage. + +2. Always send entire Python test scripts when proposing a new test or correcting one you +previously proposed. + +3. Be sure to include assertions in the test that verify any applicable postconditions. + +4. Please also make VERY SURE to clean up after the test, so as not to affect other tests; +use 'pytest-mock' if appropriate. + +5. Write as little top-level code as possible, and in particular do not include any top-level code +calling into pytest.main or the test itself. + +6. Respond with the Python code enclosed in backticks. Before answering the question, please think about it step-by-step within tags. Then, provide your final answer within tags. + +""") + ] + + + def error_prompt(self, error: str) -> T.List[dict]: + return [_message(f"""\ +{error} +Executing the test yields an error, shown above. + +1. Modify the test to correct it. +2. Respond with the complete Python code in backticks. +3. Before answering the question, please think about it step-by-step within tags. Then, provide your final answer within tags. + +""") + ] + + + def missing_coverage_prompt(self, now_missing_lines: set, now_missing_branches: set) -> T.List[dict]: + return [_message(f"""\ +This test still lacks coverage: {lines_branches_do(now_missing_lines, set(), now_missing_branches)} not execute. + +1. Modify it to execute those lines. +2. Respond with the complete Python code in backticks. +3. Before responding, please think about it step-by-step within tags. Then, provide your final answer within tags. 
+ +""") + ] + + +# prompter registry +prompters = { + "gpt-v1": Gpt4PrompterV1, + "gpt": Gpt4Prompter, + "claude": ClaudePrompter +} diff --git a/src/coverup/pytest_plugin.py b/src/coverup/pytest_plugin.py new file mode 100644 index 0000000..01b5bea --- /dev/null +++ b/src/coverup/pytest_plugin.py @@ -0,0 +1,80 @@ +import json +import typing as T +import pytest +import json +from _pytest.pathlib import Path + + +def pytest_addoption(parser): + parser.addoption('--coverup-outcomes', action='store', + type=Path, default=None, + help='Path where to store execution outcomes') + parser.addoption('--coverup-run-only', action='store', + type=Path, default=None, + help='Path to only module to execute') + parser.addoption('--coverup-stop-after', action='store', + type=Path, default=None, + help='Stop execution after reaching this module') + + +class CoverUpPlugin: + def __init__(self, config): + self._rootpath = config.rootpath + self._stop_after = config.getoption("--coverup-stop-after") + self._outcomes_file = config.getoption("--coverup-outcomes") + + if self._stop_after: self._stop_after = self._stop_after.resolve() + + self._stop_now = False + self._outcomes = {} + + def pytest_collectreport(self, report): + if report.failed: + path = self._rootpath / report.fspath + if path not in self._outcomes: + self._outcomes[path] = report.outcome + + def pytest_collection_modifyitems(self, config, items): + if (run_only := config.getoption("--coverup-run-only")): + run_only = run_only.resolve() + + selected = [] + deselected = [] + + for item in items: + if run_only == item.path: + selected.append(item) + else: + deselected.append(item) + + items[:] = selected + if deselected: + config.hook.pytest_deselected(items=deselected) + + def pytest_runtest_protocol(self, item, nextitem): + if self._stop_after == item.path: + if not nextitem or self._stop_after != nextitem.path: + self._stop_now = True + + def pytest_runtest_logreport(self, report): + path = self._rootpath / report.fspath + if path not in self._outcomes or report.outcome != 'passed': + self._outcomes[path] = report.outcome + + if self._stop_now and report.when == 'teardown': + pytest.exit(f"Stopped after {self._stop_after}") + + def write_outcomes(self): + if self._outcomes_file: + with self._outcomes_file.open("w") as f: + json.dump({str(k): v for k, v in self._outcomes.items()}, f) + + +def pytest_configure(config): + config._coverup_plugin = CoverUpPlugin(config) + config.pluginmanager.register(config._coverup_plugin, 'coverup_plugin') + + +def pytest_unconfigure(config): + if (plugin := getattr(config, "_coverup_plugin", None)): + plugin.write_outcomes() diff --git a/src/coverup/segment.py b/src/coverup/segment.py index 9da6791..fabc64e 100644 --- a/src/coverup/segment.py +++ b/src/coverup/segment.py @@ -12,6 +12,7 @@ def __init__(self, filename: Path, name: str, begin: int, end: int, executed_lines: T.Set[int], missing_branches: T.Set[T.Tuple[int, int]], context: T.List[T.Tuple[int, int]]): + self.path = Path(filename).resolve() self.filename = filename self.name = name self.begin = begin diff --git a/src/coverup/testrunner.py b/src/coverup/testrunner.py index ab386f0..b5f6e21 100644 --- a/src/coverup/testrunner.py +++ b/src/coverup/testrunner.py @@ -11,7 +11,7 @@ from .utils import subprocess_run -async def measure_test_coverage(*, test: str, tests_dir: Path, pytest_args='', log_write=None): +async def measure_test_coverage(*, test: str, tests_dir: Path, pytest_args='', log_write=None, branch_coverage=True): """Runs a given test and 
returns the coverage obtained.""" with tempfile.NamedTemporaryFile(prefix="tmp_test_", suffix='.py', dir=str(tests_dir), mode="w") as t: t.write(test) @@ -20,12 +20,14 @@ async def measure_test_coverage(*, test: str, tests_dir: Path, pytest_args='', l with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as j: try: # -qq to cut down on tokens - p = await subprocess_run((f"{sys.executable} -m slipcover --branch --json --out {j.name} " + - f"-m pytest {pytest_args} -qq -x --disable-warnings {t.name}").split(), - check=True, timeout=60) + p = await subprocess_run([sys.executable, '-m', 'slipcover'] + (['--branch'] if branch_coverage else []) + \ + ['--json', '--out', j.name, + '-m', 'pytest'] + pytest_args.split() + ['-qq', '-x', '--disable-warnings', t.name], + check=True, timeout=60) if log_write: log_write(str(p.stdout, 'UTF-8', errors='ignore')) + # not checking for JSON errors here because if pytest aborts, its RC ought to be !=0 cov = json.load(j) finally: j.close() @@ -37,53 +39,45 @@ async def measure_test_coverage(*, test: str, tests_dir: Path, pytest_args='', l return cov["files"] -def measure_suite_coverage(*, tests_dir: Path, source_dir: Path, pytest_args='', trace=None): +def measure_suite_coverage(*, tests_dir: Path, source_dir: Path, pytest_args='', trace=None, isolate_tests=False, branch_coverage=True): """Runs an entire test suite and returns the coverage obtained.""" + with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as j: try: - p = subprocess.run((f"{sys.executable} -m slipcover --source {source_dir} --branch --json --out {j.name} " + - f"-m pytest {pytest_args} -qq -x --disable-warnings {tests_dir}").split(), - check=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - - if trace: - trace(f"tests rc={p.returncode}\n") - trace(str(p.stdout, 'UTF-8', errors='ignore')) - + command = [sys.executable, + '-m', 'slipcover', '--source', source_dir] + (['--branch'] if branch_coverage else []) + \ + ['--json', '--out', j.name] + \ + (['--isolate-tests'] if isolate_tests else []) + \ + ['-m', 'pytest'] + pytest_args.split() + ['--disable-warnings', '-x', tests_dir] + + if trace: trace(command) + p = subprocess.run(command, check=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) if p.returncode not in (pytest.ExitCode.OK, pytest.ExitCode.NO_TESTS_COLLECTED): + if trace: trace(f"tests rc={p.returncode}") p.check_returncode() - return json.load(j) + try: + return json.load(j) + except json.decoder.JSONDecodeError: + # The JSON is broken, so pytest's execution likely aborted (e.g. a Python unhandled exception). + p.check_returncode() # this will almost certainly raise an exception. If not, we do it ourselves: + raise subprocess.CalledProcessError(p.returncode, command, output=p.stdout) finally: j.close() + try: os.unlink(j.name) except FileNotFoundError: pass -class ParseError(Exception): +class BadTestFinderError(Exception): pass -def parse_failed_tests(tests_dir: Path, p: (subprocess.CompletedProcess, subprocess.CalledProcessError)) -> T.List[Path]: - # FIXME use --junitxml or --report-log - output = str(p.stdout, 'UTF-8', errors='ignore') - if (m := re.search(r"\n===+ short test summary info ===+\n((?:ERROR|FAILED).*)", output, re.DOTALL)): - summary = m.group(1) - failures = [Path(f) for f in re.findall(r"^(?:ERROR|FAILED) ([^\s:]+)", summary, re.MULTILINE)] - - if tests_dir.is_absolute(): - # pytest sometimes makes absolute paths into relative ones by adding ../../.. to root... 
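For reference, `measure_suite_coverage` above hands back SlipCover's JSON report as-is; a small, hypothetical driver shows the shape of that result (the `src/mymod` and `tests` paths are just examples):

```python
# Hypothetical usage of measure_suite_coverage; paths are examples only.
from pathlib import Path
from coverup.testrunner import measure_suite_coverage

cov = measure_suite_coverage(tests_dir=Path("tests"), source_dir=Path("src/mymod"),
                             pytest_args="", isolate_tests=True, branch_coverage=True)
print(f"{cov['summary']['percent_covered']:.1f}% covered across {len(cov['files'])} file(s)")
```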
- failures = [f.resolve() for f in failures] - - return failures - - raise ParseError(f"Unable to parse failing tests out of pytest output. RC={p.returncode}; output:\n{output}") - - class BadTestsFinder(DeltaDebugger): """Finds tests that cause other tests to fail.""" - def __init__(self, *, tests_dir: Path, pytest_args: str = '', trace = None): + def __init__(self, *, tests_dir: Path, pytest_args: str = '', trace = None, progress = None, branch_coverage=True): super(BadTestsFinder, self).__init__(trace=trace) self.tests_dir = tests_dir @@ -97,10 +91,12 @@ def find_tests(p): yield f self.all_tests = set(find_tests(self.tests_dir)) - self.pytest_args = pytest_args + self.pytest_args = pytest_args.split() + self.branch_coverage = branch_coverage + self.progress = progress - def run_tests(self, tests_to_run: set = None) -> Path: + def run_tests(self, tests_to_run: set = None, stop_after: Path = None, run_only: Path = None) -> Path: """Runs the tests, by default all, returning the first one that fails, or None. Throws RuntimeError if unable to parse pytest's output. """ @@ -121,51 +117,100 @@ def link_tree(src, dst): assert self.tests_dir.parent != self.tests_dir # we need a parent directory - if self.trace: self.trace(f"running {len(test_set)} test(s).") + if self.progress: self.progress(f"running {'1/' if run_only else ('up to ' if stop_after else '')}{len(test_set)} test(s)") with tempfile.TemporaryDirectory(dir=self.tests_dir.parent) as tmpdir: tmpdir = Path(tmpdir) link_tree(self.tests_dir, tmpdir) - p = subprocess.run((f"{sys.executable} -m pytest {self.pytest_args} -qq --disable-warnings " +\ - f"--rootdir {tmpdir} {tmpdir}").split(), - check=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, timeout=2*60*60) + def to_tmpdir(p: Path): + return tmpdir / (p.relative_to(self.tests_dir)) - if p.returncode in (pytest.ExitCode.OK, pytest.ExitCode.NO_TESTS_COLLECTED): - if self.trace: self.trace(f"tests passed") - return set() - - # bring it back to its normal path - failing = set(self.tests_dir / f.relative_to(tmpdir) for f in parse_failed_tests(tmpdir, p)) + with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as outcomes_f: + try: + command = [sys.executable, + # include SlipCover in case a test is failing because of it + # (in measure_suite_coverage) + '-m', 'slipcover'] + (['--branch'] if self.branch_coverage else []) + \ + ['--out', '/dev/null', + "-m", "pytest"] + self.pytest_args + \ + ['-qq', '--disable-warnings', + '-p', 'coverup.pytest_plugin', '--coverup-outcomes', str(outcomes_f.name)] \ + + (['--coverup-stop-after', str(to_tmpdir(stop_after))] if stop_after else []) \ + + (['--coverup-run-only', str(to_tmpdir(run_only))] if run_only else []) \ + + [str(tmpdir)] +# if self.trace: self.trace(' '.join(command)) + p = subprocess.run(command, check=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, timeout=2*60*60) + + if p.returncode not in (pytest.ExitCode.OK, pytest.ExitCode.TESTS_FAILED, + pytest.ExitCode.INTERRUPTED, pytest.ExitCode.NO_TESTS_COLLECTED): + self.trace(f"tests rc={p.returncode}") + p.check_returncode() + + try: + outcomes = json.load(outcomes_f) + except json.decoder.JSONDecodeError: + # The JSON is broken, so pytest's execution likely aborted (e.g. a Python unhandled exception). + p.check_returncode() # this will almost certainly raise an exception. 
If not, we do it ourselves: + raise subprocess.CalledProcessError(p.returncode, command, output=p.stdout) + + finally: + outcomes_f.close() + try: + os.unlink(outcomes_f.name) + except FileNotFoundError: + pass + + tmpdir = tmpdir.resolve() + return {self.tests_dir / Path(p).relative_to(tmpdir): o for p, o in outcomes.items()} - if self.trace: - self.trace(f"tests rc={p.returncode} failing={_compact(failing)}\n") -# self.trace(str(p.stdout, 'UTF-8', errors='ignore')) - return failing + def test(self, test_set: set, **kwargs) -> bool: + target_test = kwargs.get('target_test') + outcomes = self.run_tests(test_set, stop_after=target_test, run_only=kwargs.get('run_only')) + if self.trace: self.trace(f"failing={_compact(set(p for p, o in outcomes.items() if o == 'failed'))}") + while target_test not in outcomes: + if self.trace: self.trace(f"{target_test} not executed, trying without failing tests.") + test_set -= set(p for p, o in outcomes.items() if o == 'failed') - def test(self, test_set: set, **kwargs) -> bool: - if not (failing_tests := self.run_tests(test_set)): - return False + if not test_set: + raise BadTestFinderError(f"Unable to narrow down cause of {target_test} failure.") - # Other tests may fail; we need to know if "our" test failed. - if kwargs.get('target_test') not in failing_tests: - return False + outcomes = self.run_tests(test_set, stop_after=target_test, run_only=kwargs.get('run_only')) + if self.trace: self.trace(f"failing={_compact(set(p for p, o in outcomes.items() if o == 'failed'))}") - return True # "interesting"/"reproduced" + return outcomes[kwargs.get('target_test')] == 'failed' def find_culprit(self, failing_test: Path, *, test_set = None) -> T.Set[Path]: """Returns the set of tests causing 'failing_test' to fail.""" assert failing_test in self.all_tests - # TODO first test collection using --collect-only, with short timeout - # TODO reduce timeout for actually running tests - -# we unfortunately can't do this... the code that causes test(s) to fail may execute during pytest collection. -# sorted_tests = sorted(self.all_tests) -# test_set = set(sorted_tests[:sorted_tests.index(failing_test)]) -# assert self.test(test_set), "Test set doesn't fail!" 
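`run_tests` above loads the new `coverup.pytest_plugin` through pytest's `-p` option and reads back a JSON map of per-file outcomes; the plugin can also be driven by hand, roughly as sketched here (the `outcomes.json` and `tests` names are arbitrary examples):

```python
# Hypothetical manual run of the outcome-collecting plugin.
import json
import subprocess
import sys

subprocess.run([sys.executable, "-m", "pytest", "-qq", "--disable-warnings",
                "-p", "coverup.pytest_plugin",
                "--coverup-outcomes", "outcomes.json",
                "tests"],
               check=False)  # a non-zero exit code just means some tests failed

with open("outcomes.json") as f:
    outcomes = json.load(f)  # {test file path: "passed", "failed", ...}

print([path for path, outcome in outcomes.items() if outcome == "failed"])
```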
- - changes = set(test_set if test_set is not None else self.all_tests) - {failing_test} - return self.debug(changes=changes, rest={failing_test}, target_test=failing_test) + if self.trace: self.trace(f"checking that {failing_test} still fails...") + outcomes = self.run_tests(stop_after=failing_test) + if outcomes[failing_test] != 'failed': + if self.trace: self.trace("it doesn't!") + raise BadTestFinderError(f"Unable to narrow down causes of failure") + + if self.trace: self.trace(f"checking if {failing_test} passes by itself...") + outcomes = self.run_tests({failing_test}) + if outcomes[failing_test] != 'passed': + if self.trace: self.trace("it doesn't!") + return {failing_test} + + if self.trace: self.trace("checking if failure is caused by test collection code...") + tests_to_run = set(test_set if test_set is not None else self.all_tests) - {failing_test} + outcomes = self.run_tests(tests_to_run.union({failing_test}), run_only=failing_test) + + if outcomes[failing_test] == 'failed': + if self.trace: print("Issue is in test collection code; looking for culprit...") + culprits = self.debug(changes=tests_to_run, rest={failing_test}, + target_test=failing_test, run_only=failing_test) + else: + if self.trace: print("Issue is in test run code; looking for culprit...") + culprits = self.debug(changes=tests_to_run, rest={failing_test}, target_test=failing_test) + + if culprits == tests_to_run: + raise BadTestFinderError(f"Unable to narrow down causes of failure") + + return culprits diff --git a/src/coverup/utils.py b/src/coverup/utils.py index 47b0601..8ea536a 100644 --- a/src/coverup/utils.py +++ b/src/coverup/utils.py @@ -3,26 +3,6 @@ import subprocess -class TemporaryOverwrite: - """Context handler that overwrites a file, and restores it upon exit.""" - def __init__(self, file: Path, new_content: str): - self.file = file - self.new_content = new_content - self.backup = file.parent / (file.name + ".bak") if file.exists() else None - - def __enter__(self): - if self.file.exists(): - self.file.replace(self.backup) - - self.file.write_text(self.new_content) - self.file.touch() - - def __exit__(self, exc_type, exc_value, traceback): - self.file.unlink() - if self.backup: - self.backup.replace(self.file) - - def format_ranges(lines: T.Set[int], negative: T.Set[int]) -> str: """Formats sets of line numbers as comma-separated lists, collapsing neighboring lines into ranges for brevity.""" diff --git a/tests/test_coverup.py b/tests/test_coverup.py index 8b28c3e..c79b213 100644 --- a/tests/test_coverup.py +++ b/tests/test_coverup.py @@ -83,10 +83,11 @@ def test_clean_error_error(): def test_find_imports(): - assert ['abc', 'bar', 'baz', 'cba', 'foo', 'xy'] == sorted(coverup.find_imports("""\ + assert ['abc', 'bar', 'baz', 'cba', 'foo'] == sorted(coverup.find_imports("""\ import foo, bar.baz from baz.zab import a, b, c -from ..xy import yz +from ..xy import yz # relative, package likely present +from . 
 import blah # relative, package likely present
 def foo_func():
     import abc
@@ -102,12 +103,6 @@ def test_missing_imports():
     assert coverup.missing_imports(['sys', 'idontexist'])
 
 
-def test_get_module_name():
-    assert 'flask.json.provider' == coverup.get_module_name(Path('src/flask/json/provider.py'), Path('src/flask'))
-    assert 'flask.json.provider' == coverup.get_module_name('src/flask/json/provider.py', 'src/flask')
-    assert 'flask.tool' == coverup.get_module_name('src/flask/tool.py', './tests/../src/flask')
-    assert None == coverup.get_module_name('src/flask/tool.py', './tests')
-
 def test_extract_python():
     assert "foo()\n\nbar()\n" == coverup.extract_python("""\
 ```python
diff --git a/tests/test_llm.py b/tests/test_llm.py
index ec583fc..e971493 100644
--- a/tests/test_llm.py
+++ b/tests/test_llm.py
@@ -6,9 +6,15 @@ def test_compute_cost():
     assert pytest.approx(0.033, abs=.001) == \
         llm.compute_cost({'prompt_tokens':1100, 'completion_tokens':0}, 'gpt-4')
 
+    assert pytest.approx(0.033, abs=.001) == \
+        llm.compute_cost({'prompt_tokens':1100, 'completion_tokens':0}, 'openai/gpt-4')
+
     assert pytest.approx(2.10, abs=.01) == \
         llm.compute_cost({'prompt_tokens':60625, 'completion_tokens':4731}, 'gpt-4')
 
+    assert pytest.approx(2.10, abs=.01) == \
+        llm.compute_cost({'prompt_tokens':60625, 'completion_tokens':4731}, 'openai/gpt-4')
+
     # unknown model
     assert None == llm.compute_cost({'prompt_tokens':60625, 'completion_tokens':4731}, 'unknown')
 
diff --git a/tests/test_prompt.py b/tests/test_prompt.py
new file mode 100644
index 0000000..a702e3c
--- /dev/null
+++ b/tests/test_prompt.py
@@ -0,0 +1,13 @@
+from pathlib import Path
+from coverup import prompt
+
+
+def test_get_module_name():
+    fpath = Path('src/flask/json/provider.py').resolve()
+    srcpath = Path('src/flask').resolve()
+
+    assert 'flask.json.provider' == prompt.get_module_name(fpath, srcpath)
+
+    assert None == prompt.get_module_name(fpath, Path('./tests').resolve())
+
+
diff --git a/tests/test_testrunner.py b/tests/test_testrunner.py
index 3255d04..d02f542 100644
--- a/tests/test_testrunner.py
+++ b/tests/test_testrunner.py
@@ -2,6 +2,7 @@
 import subprocess
 import coverup.testrunner as tr
 from pathlib import Path
+import tempfile
 
 
 @pytest.mark.asyncio
@@ -9,23 +10,31 @@ async def test_measure_test_coverage_exit_1(tmpdir):
     with pytest.raises(subprocess.CalledProcessError) as einfo:
         await tr.measure_test_coverage(test="import os;\ndef test_foo(): os.exit(1)\n", tests_dir=Path(tmpdir))
 
-def test_measure_suite_coverage_empty_dir(tmpdir):
-    coverage = tr.measure_suite_coverage(tests_dir=Path(tmpdir), source_dir=Path('src'))  # shouldn't throw
-    assert coverage['summary']['covered_lines'] == 0
+
+@pytest.mark.parametrize("absolute", [True, False])
+def test_measure_suite_coverage_empty_dir(absolute):
+    with tempfile.TemporaryDirectory(dir=Path('.')) as tests_dir:
+        tests_dir = Path(tests_dir)
+        if absolute:
+            tests_dir = tests_dir.resolve()
+
+        coverage = tr.measure_suite_coverage(tests_dir=tests_dir, source_dir=tests_dir, trace=print)  # shouldn't throw
+        assert {} == coverage['files']
 
 
 def seq2p(tests_dir, seq):
     return tests_dir / f"test_coverup_{seq}.py"
 
 
+N_TESTS=10
 def make_failing_suite(tests_dir: Path, fail_collect: bool):
     """In a suite with 10 tests, test 6 fails; test 3 doesn't fail, but causes 6 to fail."""
 
-    for seq in range(10):
+    for seq in range(N_TESTS):
         seq2p(tests_dir, seq).write_text('def test_foo(): pass')
 
     culprit = seq2p(tests_dir, 3)
-    culprit.write_text("import sys\n" + "sys.foobar = True")
+    culprit.write_text("import sys\n" + "sys.foobar = True\n" + "def test_foo(): pass")
 
     failing = seq2p(tests_dir, 6)
     if fail_collect:
@@ -37,16 +46,32 @@
 
 
 @pytest.mark.parametrize("fail_collect", [True, False])
-def test_measure_suite_coverage_test_fails(tmpdir, fail_collect):
+@pytest.mark.parametrize("absolute", [True, False])
+def test_measure_suite_coverage_test_fails(absolute, fail_collect):
+    with tempfile.TemporaryDirectory(dir=Path('.')) as tests_dir:
+        tests_dir = Path(tests_dir)
+        if absolute:
+            tests_dir = tests_dir.resolve()
 
-    tests_dir = Path(tmpdir)
+        failing, culprit = make_failing_suite(tests_dir, fail_collect)
 
-    failing, culprit = make_failing_suite(tests_dir, fail_collect)
+        with pytest.raises(subprocess.CalledProcessError) as einfo:
+            tr.measure_suite_coverage(tests_dir=tests_dir, source_dir=Path('src'), isolate_tests=False)
 
-    with pytest.raises(subprocess.CalledProcessError) as einfo:
-        tr.measure_suite_coverage(tests_dir=tests_dir, source_dir=Path('src'))
 
-    assert [failing] == tr.parse_failed_tests(tests_dir, einfo.value)
+@pytest.mark.parametrize("fail_collect", [True, False])
+@pytest.mark.parametrize("absolute", [True, False])
+def test_measure_suite_coverage_isolated(absolute, fail_collect):
+    with tempfile.TemporaryDirectory(dir=Path('.')) as tests_dir:
+        tests_dir = Path(tests_dir)
+        if absolute:
+            tests_dir = tests_dir.resolve()
+
+        failing, culprit = make_failing_suite(tests_dir, fail_collect)
+
+        tr.measure_suite_coverage(tests_dir=tests_dir, source_dir=Path('src'), isolate_tests=True)
+
+        # FIXME check coverage
 
 
 def test_finds_tests_in_subdir(tmpdir):
@@ -63,83 +88,124 @@
 
 
 @pytest.mark.parametrize("fail_collect", [True, False])
-def test_run_tests(tmpdir, fail_collect):
-    tests_dir = Path(tmpdir)
+@pytest.mark.parametrize("absolute", [True, False])
+def test_run_tests(absolute, fail_collect):
+    with tempfile.TemporaryDirectory(dir=Path('.')) as tests_dir:
+        tests_dir = Path(tests_dir)
+        if absolute:
+            tests_dir = tests_dir.resolve()
 
-    failing, _ = make_failing_suite(tests_dir, fail_collect)
+        failing, _ = make_failing_suite(tests_dir, fail_collect)
 
-    btf = tr.BadTestsFinder(tests_dir=tests_dir, trace=print)
-    assert {failing} == btf.run_tests()
+        btf = tr.BadTestsFinder(tests_dir=tests_dir, trace=print)
+        outcomes = btf.run_tests()
+        if fail_collect:
+            assert len(outcomes) == 1
+        else:
+            assert len(outcomes) == N_TESTS
 
-def test_run_tests_multiple_failures(tmpdir):
-    tests_dir = Path(tmpdir)
+        assert {failing} == set(p for p, o in outcomes.items() if o != 'passed')
 
-    for seq in range(10):
-        seq2p(tests_dir, seq).write_text("import sys\n" + "def test_foo(): assert not getattr(sys, 'foobar', False)")
 
-    culprit = seq2p(tests_dir, 3)
-    culprit.write_text("import sys\n" + "sys.foobar = True")
+@pytest.mark.parametrize("absolute", [True, False])
+@pytest.mark.parametrize("fail_collect", [True, False])
+def test_run_tests_run_single(absolute, fail_collect):
+    with tempfile.TemporaryDirectory(dir=Path('.')) as tests_dir:
+        tests_dir = Path(tests_dir)
+        if absolute:
+            tests_dir = tests_dir.resolve()
 
-    btf = tr.BadTestsFinder(tests_dir=tests_dir, trace=print)
-    failing = btf.run_tests()
+        failing, _ = make_failing_suite(tests_dir, fail_collect=fail_collect)
 
-    for seq in range(10):
-        if seq != 3: assert seq2p(tests_dir, seq) in failing
+        non_failing = seq2p(tests_dir, 2)
+        assert non_failing != failing
 
-def test_run_tests_no_tests(tmpdir):
-    tests_dir = Path(tmpdir)
+        btf = tr.BadTestsFinder(tests_dir=tests_dir, trace=print)
+        outcomes = btf.run_tests(run_only=non_failing)
 
-    (tests_dir / "test_foo.py").write_text("# no tests here")
+        assert len(outcomes) == 1
+        if fail_collect:
+            assert outcomes[failing] != 'passed'
+        else:
+            assert outcomes[non_failing] == 'passed'
 
-    btf = tr.BadTestsFinder(tests_dir=tests_dir, trace=print)
-    failed = btf.run_tests()
-    assert failed == set()
+
+def test_run_tests_multiple_failures():
+    with tempfile.TemporaryDirectory(dir=Path('.')) as tests_dir:
+        tests_dir = Path(tests_dir)
+
+        for seq in range(10):
+            seq2p(tests_dir, seq).write_text("import sys\n" + "def test_foo(): assert not getattr(sys, 'foobar', False)")
+
+        culprit = seq2p(tests_dir, 3)
+        culprit.write_text("import sys\n" + "sys.foobar = True")
+
+        btf = tr.BadTestsFinder(tests_dir=tests_dir, trace=print)
+        outcomes = btf.run_tests()
+
+        assert len(outcomes) == 9  # no tests in 3
+        for seq in range(10):
+            if seq != 3: assert outcomes[seq2p(tests_dir, seq)] != 'passed'
+
+
+def test_run_tests_no_tests():
+    with tempfile.TemporaryDirectory(dir=Path('.')) as tests_dir:
+        tests_dir = Path(tests_dir)
+
+        (tests_dir / "test_foo.py").write_text("# no tests here")
+
+        btf = tr.BadTestsFinder(tests_dir=tests_dir, trace=print)
+        outcomes = btf.run_tests()
+        assert len(outcomes) == 0
 
 
 @pytest.mark.parametrize("fail_collect", [True, False])
-def test_find_culprit(tmpdir, fail_collect):
-    tests_dir = Path(tmpdir)
+def test_find_culprit(fail_collect):
+    with tempfile.TemporaryDirectory(dir=Path('.')) as tests_dir:
+        tests_dir = Path(tests_dir)
 
-    failing, culprit = make_failing_suite(tests_dir, fail_collect)
+        failing, culprit = make_failing_suite(tests_dir, fail_collect)
 
-    btf = tr.BadTestsFinder(tests_dir=tests_dir, trace=print)
+        btf = tr.BadTestsFinder(tests_dir=tests_dir, trace=print)
 
-    assert not btf.run_tests({culprit})
-    assert {culprit} == btf.find_culprit(failing)
+        assert 'passed' == btf.run_tests({culprit})[culprit]
+        assert {culprit} == btf.find_culprit(failing)
 
 
-def test_find_culprit_multiple_failures(tmpdir):
-    tests_dir = Path(tmpdir)
+def test_find_culprit_multiple_failures():
+    with tempfile.TemporaryDirectory(dir=Path('.')) as tests_dir:
+        tests_dir = Path(tests_dir)
 
-    for seq in range(10):
-        seq2p(tests_dir, seq).write_text("import sys\n" + "def test_foo(): assert not getattr(sys, 'foobar', False)")
+        for seq in range(10):
+            seq2p(tests_dir, seq).write_text("import sys\n" + "def test_foo(): assert not getattr(sys, 'foobar', False)")
 
-    culprit = seq2p(tests_dir, 3)
-    culprit.write_text("import sys\n" + "sys.foobar = True")
+        culprit = seq2p(tests_dir, 3)
+        culprit.write_text("import sys\n" + "sys.foobar = True")
 
-    btf = tr.BadTestsFinder(tests_dir=tests_dir, trace=print)
-    assert {culprit} == btf.find_culprit(seq2p(tests_dir, 6))
+        btf = tr.BadTestsFinder(tests_dir=tests_dir, trace=print)
+        assert {culprit} == btf.find_culprit(seq2p(tests_dir, 6))
 
 
 @pytest.mark.skip(reason="no good handling for this yet, it takes a long time")
-def test_find_culprit_hanging_collect(tmpdir):
-    tests_dir = Path(tmpdir)
+def test_find_culprit_hanging_collect():
+    with tempfile.TemporaryDirectory(dir=Path('.')) as tests_dir:
+        tests_dir = Path(tests_dir)
 
-    all_tests = {seq2p(tests_dir, seq) for seq in range(10)}
-    for t in all_tests:
-        t.write_text('def test_foo(): pass')
+        all_tests = {seq2p(tests_dir, seq) for seq in range(10)}
+        for t in all_tests:
+            t.write_text('def test_foo(): pass')
 
-    culprit = seq2p(tests_dir, 3)
-    culprit.write_text("""\
-import pytest
+        culprit = seq2p(tests_dir, 3)
+        culprit.write_text("""\
+    import pytest
 
-def test_foo(): pass
+    def test_foo(): pass
 
-pytest.main(["--verbose"])
-pytest.main(["--verbose"])
-pytest.main(["--verbose"])
-""")
+    pytest.main(["--verbose"])
+    pytest.main(["--verbose"])
+    pytest.main(["--verbose"])
+    """)
 
-    btf = tr.BadTestsFinder(tests_dir=tests_dir, trace=print)
-    btf.run_tests()
+        btf = tr.BadTestsFinder(tests_dir=tests_dir, trace=print)
+        btf.run_tests()