-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
786cbc0
commit 279b342
Showing
20 changed files
with
618 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# file src/cwhy/conversation/explain_functions.py:12-13 | ||
# lines [12, 13] | ||
# branches [] | ||
|
||
import argparse | ||
import pytest | ||
from cwhy.conversation.explain_functions import ExplainFunctions | ||
|
||
# Test function to cover ExplainFunctions.__init__ | ||
def test_explain_functions_init(): | ||
# Create a mock argparse.Namespace object | ||
mock_args = argparse.Namespace() | ||
# Instantiate ExplainFunctions with mock_args | ||
explain_func = ExplainFunctions(mock_args) | ||
# Assert that the args attribute is correctly set | ||
assert explain_func.args is mock_args |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# file src/cwhy/conversation/diff_functions.py:48-79 | ||
# lines [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 59, 60, 61, 63, 64, 65, 67, 68, 69, 72] | ||
# branches [] | ||
|
||
import pytest | ||
from cwhy.conversation.diff_functions import DiffFunctions | ||
|
||
class MockArgs: | ||
pass | ||
|
||
def test_apply_modification_schema(): | ||
mock_args = MockArgs() | ||
diff_functions = DiffFunctions(mock_args) | ||
schema = diff_functions.apply_modification_schema() | ||
assert schema['name'] == 'apply_modification' | ||
assert schema['description'] == 'Applies a single modification to the source file with the goal of fixing any existing compilation errors.' | ||
assert 'parameters' in schema | ||
params = schema['parameters'] | ||
assert params['type'] == 'object' | ||
assert 'properties' in params | ||
properties = params['properties'] | ||
assert 'filename' in properties | ||
assert properties['filename']['type'] == 'string' | ||
assert 'start-line-number' in properties | ||
assert properties['start-line-number']['type'] == 'integer' | ||
assert 'number-lines-remove' in properties | ||
assert properties['number-lines-remove']['type'] == 'integer' | ||
assert 'replacement' in properties | ||
assert properties['replacement']['type'] == 'string' | ||
assert 'required' in params | ||
required_fields = params['required'] | ||
assert 'filename' in required_fields | ||
assert 'start-line-number' in required_fields | ||
assert 'number-lines-remove' in required_fields | ||
assert 'replacement' in required_fields |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# file src/cwhy/prompts.py:194-198 | ||
# lines [194, 195, 196, 197] | ||
# branches [] | ||
|
||
import argparse | ||
from cwhy.prompts import explain_prompt | ||
import pytest | ||
from unittest.mock import patch | ||
|
||
# Assuming the _base_prompt function is something like this: | ||
# def _base_prompt(args: argparse.Namespace, diagnostic: str) -> str: | ||
# return f"Diagnostic: {diagnostic}\n" | ||
|
||
# Test function to cover explain_prompt | ||
def test_explain_prompt(): | ||
# Mock the _base_prompt function to control its output | ||
with patch('cwhy.prompts._base_prompt', return_value="Mocked base prompt. ") as mock_base_prompt: | ||
# Create a mock argparse.Namespace object | ||
mock_args = argparse.Namespace() | ||
diagnostic = "Sample diagnostic message." | ||
|
||
# Call the function under test | ||
result = explain_prompt(mock_args, diagnostic) | ||
|
||
# Verify the result includes the mocked base prompt and the additional text | ||
expected_result = "Mocked base prompt. What's the problem? If you can, suggest code to fix the issue." | ||
assert result == expected_result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# file src/cwhy/__main__.py:51-58 | ||
# lines [51, 52, 53, 54, 56, 57, 58] | ||
# branches ['57->exit', '57->58'] | ||
|
||
import argparse | ||
from unittest.mock import MagicMock | ||
import pytest | ||
from cwhy.__main__ import RichArgParser | ||
|
||
def test_rich_arg_parser_print_message(monkeypatch): | ||
# Mock the Console object and its print method | ||
mock_console = MagicMock() | ||
monkeypatch.setattr('cwhy.__main__.Console', mock_console) | ||
|
||
# Create an instance of RichArgParser | ||
parser = RichArgParser() | ||
|
||
# Call the _print_message method with a test message | ||
test_message = "Test message" | ||
parser._print_message(test_message) | ||
|
||
# Assert that the console's print method was called with the test message | ||
mock_console.return_value.print.assert_called_once_with(test_message) | ||
|
||
# Call the _print_message method with None to test the branch where message is not printed | ||
parser._print_message(None) | ||
|
||
# Assert that the console's print method was still called only once (with the previous message) | ||
assert mock_console.return_value.print.call_count == 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# file src/cwhy/conversation/explain_functions.py:49-52 | ||
# lines [49, 50, 51, 52] | ||
# branches [] | ||
|
||
import pytest | ||
from src.cwhy.conversation.explain_functions import ExplainFunctions | ||
from unittest.mock import MagicMock, patch | ||
|
||
@pytest.fixture | ||
def explain_functions(): | ||
mock_args = MagicMock() | ||
mock_args.command = ["python", "script.py"] | ||
return ExplainFunctions(args=mock_args) | ||
|
||
def test_get_compile_or_run_command(explain_functions): | ||
with patch('src.cwhy.conversation.explain_functions.dprint') as mock_dprint: | ||
# Call the method and assert the result | ||
result = explain_functions.get_compile_or_run_command() | ||
assert result == "python script.py" | ||
|
||
# Assert that dprint was called with the correct argument | ||
mock_dprint.assert_called_once_with("python script.py") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# file src/cwhy/conversation/explain_functions.py:15-23 | ||
# lines [15, 16, 17, 18, 19, 20, 21] | ||
# branches [] | ||
|
||
import pytest | ||
from unittest.mock import MagicMock | ||
from cwhy.conversation.explain_functions import ExplainFunctions | ||
|
||
class TestExplainFunctions: | ||
def test_as_tools(self, monkeypatch): | ||
# Create a MagicMock to simulate the schema methods | ||
monkeypatch.setattr(ExplainFunctions, 'get_compile_or_run_command_schema', MagicMock(return_value={'name': 'compile_or_run'})) | ||
monkeypatch.setattr(ExplainFunctions, 'get_code_surrounding_schema', MagicMock(return_value={'name': 'code_surrounding'})) | ||
monkeypatch.setattr(ExplainFunctions, 'list_directory_schema', MagicMock(return_value={'name': 'list_directory'})) | ||
|
||
# Mock the __init__ method to not require arguments | ||
monkeypatch.setattr(ExplainFunctions, '__init__', lambda x, args: None) | ||
|
||
ef = ExplainFunctions(args=None) | ||
tools = ef.as_tools() | ||
|
||
# Verify that the list contains all three schemas | ||
assert len(tools) == 3 | ||
assert tools[0]['function']['name'] == 'compile_or_run' | ||
assert tools[1]['function']['name'] == 'code_surrounding' | ||
assert tools[2]['function']['name'] == 'list_directory' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# file src/cwhy/conversation/utils.py:7-34 | ||
# lines [7, 11, 12, 13, 14, 16, 17, 19, 20, 21, 22, 24, 25, 26, 27, 28, 29, 30, 32, 33, 34] | ||
# branches ['19->20', '19->32', '20->21', '20->24', '28->19', '28->29', '32->33', '32->34'] | ||
|
||
import argparse | ||
from unittest.mock import patch | ||
import pytest | ||
|
||
# Assuming llm_utils.count_tokens is a function that needs to be mocked | ||
from cwhy.conversation import llm_utils | ||
from cwhy.conversation.utils import get_truncated_error_message | ||
|
||
@pytest.fixture | ||
def mock_count_tokens(): | ||
with patch('cwhy.conversation.llm_utils.count_tokens') as mock: | ||
mock.side_effect = lambda llm, text: len(text.split()) | ||
yield mock | ||
|
||
def test_get_truncated_error_message(mock_count_tokens): | ||
args = argparse.Namespace() | ||
args.llm = "mock_llm" | ||
args.max_error_tokens = 10 # Set a max token limit for the test | ||
|
||
# Create a diagnostic message that will exceed the max_error_tokens when combined | ||
diagnostic = "Error on line 1\nError on line 2\nError on line 3\nError on line 4" | ||
|
||
truncated_message = get_truncated_error_message(args, diagnostic) | ||
|
||
# Check that the message was indeed truncated | ||
assert "[...]" in truncated_message | ||
|
||
# Check that the truncated message does not exceed max_error_tokens | ||
assert len(truncated_message.split()) <= args.max_error_tokens | ||
|
||
# Check that the first and last lines are present in the truncated message | ||
assert "Error on line 1" in truncated_message | ||
assert "Error on line 4" in truncated_message | ||
|
||
# Check that the middle lines are not present in the truncated message | ||
assert "Error on line 2" not in truncated_message | ||
assert "Error on line 3" not in truncated_message |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# file src/cwhy/prompts.py:201-205 | ||
# lines [202, 203, 204] | ||
# branches [] | ||
|
||
import argparse | ||
import pytest | ||
from unittest.mock import patch | ||
from cwhy.prompts import diff_prompt | ||
|
||
# Assuming _base_prompt is a function that needs to be mocked | ||
# and that it is imported in the test file as well | ||
|
||
@pytest.fixture | ||
def mock_base_prompt(): | ||
with patch('cwhy.prompts._base_prompt', return_value="Base prompt output. ") as mock: | ||
yield mock | ||
|
||
def test_diff_prompt(mock_base_prompt): | ||
# Create a Namespace object to simulate argparse arguments | ||
args = argparse.Namespace() | ||
diagnostic = "Some diagnostic message" | ||
|
||
# Call the function under test | ||
result = diff_prompt(args, diagnostic) | ||
|
||
# Verify that the _base_prompt function was called with the correct arguments | ||
mock_base_prompt.assert_called_once_with(args, diagnostic) | ||
|
||
# Verify the result includes the expected additional string | ||
expected_result = "Base prompt output. Help fix this issue by providing a diff in JSON format." | ||
assert result == expected_result | ||
|
||
# No cleanup is necessary as the mock ensures no side effects on other tests |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# file src/cwhy/print_debug.py:28-30 | ||
# lines [30] | ||
# branches [] | ||
|
||
import pytest | ||
from cwhy.print_debug import enable_debug_printing | ||
|
||
# Assuming that _debug is a global variable within the cwhy.print_debug module | ||
# and that there is no direct way to check if _debug is True other than calling | ||
# enable_debug_printing and observing its effects. | ||
|
||
def test_enable_debug_printing(monkeypatch): | ||
# Use monkeypatch to set _debug to False before the test | ||
monkeypatch.setattr('cwhy.print_debug._debug', False) | ||
|
||
# Call the function that should set _debug to True | ||
enable_debug_printing() | ||
|
||
# Import _debug to check its value after enable_debug_printing() is called | ||
from cwhy.print_debug import _debug | ||
assert _debug == True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# file src/cwhy/cwhy.py:143-159 | ||
# lines [144, 145, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 159] | ||
# branches ['144->145', '144->147', '147->148', '147->149', '149->150', '149->154', '154->155', '154->156', '156->157', '156->159'] | ||
|
||
import pytest | ||
from src.cwhy.cwhy import evaluate | ||
from unittest.mock import MagicMock, patch | ||
|
||
# Define a test function to cover lines 144-159 | ||
@pytest.fixture | ||
def mock_evaluate_with_fallback(): | ||
with patch('src.cwhy.cwhy.evaluate_with_fallback') as mock: | ||
yield mock | ||
|
||
@pytest.fixture | ||
def mock_evaluate_text_prompt(): | ||
with patch('src.cwhy.cwhy.evaluate_text_prompt') as mock: | ||
yield mock | ||
|
||
@pytest.fixture | ||
def mock_evaluate_diff(): | ||
with patch('src.cwhy.cwhy.evaluate_diff') as mock: | ||
yield mock | ||
|
||
@pytest.fixture | ||
def mock_converse(): | ||
with patch('src.cwhy.cwhy.conversation.converse') as mock: | ||
yield mock | ||
|
||
@pytest.fixture | ||
def mock_diff_converse(): | ||
with patch('src.cwhy.cwhy.conversation.diff_converse') as mock: | ||
yield mock | ||
|
||
@pytest.fixture | ||
def mock_prompts_explain_prompt(): | ||
with patch('src.cwhy.cwhy.prompts.explain_prompt') as mock: | ||
yield mock | ||
|
||
def test_evaluate_coverage(mock_evaluate_with_fallback, mock_evaluate_text_prompt, mock_evaluate_diff, mock_converse, mock_diff_converse, mock_prompts_explain_prompt): | ||
# Create a mock args object with the necessary attributes | ||
mock_args = MagicMock() | ||
mock_stdin = MagicMock() | ||
|
||
# Test the "default" llm branch | ||
mock_args.llm = "default" | ||
evaluate(mock_args, mock_stdin) | ||
mock_evaluate_with_fallback.assert_called_once_with(mock_args, mock_stdin) | ||
|
||
# Test the "explain" subcommand branch | ||
mock_args.llm = "non-default" | ||
mock_args.subcommand = "explain" | ||
evaluate(mock_args, mock_stdin) | ||
mock_prompts_explain_prompt.assert_called_once_with(mock_args, mock_stdin) | ||
mock_evaluate_text_prompt.assert_called_once_with(mock_args, mock_prompts_explain_prompt.return_value) | ||
|
||
# Test the "diff" subcommand branch | ||
mock_args.subcommand = "diff" | ||
mock_evaluate_diff.return_value.choices = [MagicMock(message=MagicMock(tool_calls=[MagicMock(function=MagicMock(arguments='diff-args'))]))] | ||
result = evaluate(mock_args, mock_stdin) | ||
assert result == 'diff-args' | ||
|
||
# Test the "converse" subcommand branch | ||
mock_args.subcommand = "converse" | ||
evaluate(mock_args, mock_stdin) | ||
mock_converse.assert_called_once_with(mock_args, mock_stdin) | ||
|
||
# Test the "diff-converse" subcommand branch | ||
mock_args.subcommand = "diff-converse" | ||
evaluate(mock_args, mock_stdin) | ||
mock_diff_converse.assert_called_once_with(mock_args, mock_stdin) | ||
|
||
# Test the unknown subcommand branch | ||
mock_args.subcommand = "unknown" | ||
with pytest.raises(Exception) as exc_info: | ||
evaluate(mock_args, mock_stdin) | ||
assert str(exc_info.value) == f"unknown subcommand: {mock_args.subcommand}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# file src/cwhy/cwhy.py:26-43 | ||
# lines [27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42] | ||
# branches [] | ||
|
||
import pytest | ||
from unittest.mock import patch, call | ||
from src.cwhy.cwhy import print_key_info | ||
|
||
def test_print_key_info(): | ||
with patch('src.cwhy.cwhy.dprint') as mock_dprint: | ||
print_key_info() | ||
assert mock_dprint.call_count == 15 | ||
calls = [ | ||
call("You need a key (or keys) from an AI service to use CWhy."), | ||
call(), | ||
call("OpenAI:"), | ||
call(" You can get a key here: https://platform.openai.com/api-keys"), | ||
call(" Set the environment variable OPENAI_API_KEY to your key value:"), | ||
call(" export OPENAI_API_KEY=<your key>"), | ||
call(), | ||
call("Bedrock:"), | ||
call(" To use Bedrock, you need an AWS account."), | ||
call(" Set the following environment variables:"), | ||
call(" export AWS_ACCESS_KEY_ID=<your key id>"), | ||
call(" export AWS_SECRET_ACCESS_KEY=<your secret key>"), | ||
call(" export AWS_REGION_NAME=us-west-2"), | ||
call(" You also need to request access to Claude:"), | ||
call( | ||
" https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html#manage-model-access" | ||
), | ||
] | ||
mock_dprint.assert_has_calls(calls, any_order=True) |
Oops, something went wrong.