Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to merge to a specific target branch instead of main #5109

Merged
merged 3 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions openhands/resolver/send_pull_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def send_pull_request(
pr_type: str,
fork_owner: str | None = None,
additional_message: str | None = None,
target_branch: str | None = None,
) -> str:
if pr_type not in ['branch', 'draft', 'ready']:
raise ValueError(f'Invalid pr_type: {pr_type}')
Expand All @@ -224,12 +225,19 @@ def send_pull_request(
attempt += 1
branch_name = f'{base_branch_name}-try{attempt}'

# Get the default branch
print('Getting default branch...')
response = requests.get(f'{base_url}', headers=headers)
response.raise_for_status()
default_branch = response.json()['default_branch']
print(f'Default branch: {default_branch}')
# Get the default branch or use specified target branch
print('Getting base branch...')
if target_branch:
base_branch = target_branch
# Verify the target branch exists
response = requests.get(f'{base_url}/branches/{target_branch}', headers=headers)
if response.status_code != 200:
raise ValueError(f'Target branch {target_branch} does not exist')
else:
response = requests.get(f'{base_url}', headers=headers)
response.raise_for_status()
base_branch = response.json()['default_branch']
print(f'Base branch: {base_branch}')

# Create and checkout the new branch
print('Creating new branch...')
Expand Down Expand Up @@ -279,7 +287,7 @@ def send_pull_request(
'title': pr_title, # No need to escape title for GitHub API
'body': pr_body,
'head': branch_name,
'base': default_branch,
'base': base_branch,
'draft': pr_type == 'draft',
}
response = requests.post(f'{base_url}/pulls', headers=headers, json=data)
Expand Down Expand Up @@ -435,6 +443,7 @@ def process_single_issue(
llm_config: LLMConfig,
fork_owner: str | None,
send_on_failure: bool,
target_branch: str | None = None,
) -> None:
if not resolver_output.success and not send_on_failure:
print(
Expand Down Expand Up @@ -484,6 +493,7 @@ def process_single_issue(
llm_config=llm_config,
fork_owner=fork_owner,
additional_message=resolver_output.success_explanation,
target_branch=target_branch,
)


Expand All @@ -508,6 +518,7 @@ def process_all_successful_issues(
llm_config,
fork_owner,
False,
None,
)


Expand Down Expand Up @@ -573,6 +584,12 @@ def main():
default=None,
help='Base URL for the LLM model.',
)
parser.add_argument(
'--target-branch',
type=str,
default=None,
help='Target branch to create the pull request against (defaults to repository default branch)',
)
my_args = parser.parse_args()

github_token = (
Expand Down Expand Up @@ -625,6 +642,7 @@ def main():
llm_config,
my_args.fork_owner,
my_args.send_on_failure,
my_args.target_branch,
)


Expand Down
91 changes: 76 additions & 15 deletions tests/unit/resolver/test_send_pull_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,17 @@ def test_update_existing_pull_request(
)


@pytest.mark.parametrize('pr_type', ['branch', 'draft', 'ready'])
@pytest.mark.parametrize(
'pr_type,target_branch',
[
('branch', None),
('draft', None),
('ready', None),
('branch', 'feature'),
('draft', 'develop'),
('ready', 'staging'),
],
)
@patch('subprocess.run')
@patch('requests.post')
@patch('requests.get')
Expand All @@ -334,14 +344,22 @@ def test_send_pull_request(
mock_output_dir,
mock_llm_config,
pr_type,
target_branch,
):
repo_path = os.path.join(mock_output_dir, 'repo')

# Mock API responses
mock_get.side_effect = [
MagicMock(status_code=404), # Branch doesn't exist
MagicMock(json=lambda: {'default_branch': 'main'}),
]
# Mock API responses based on whether target_branch is specified
if target_branch:
mock_get.side_effect = [
MagicMock(status_code=404), # Branch doesn't exist
MagicMock(status_code=200), # Target branch exists
]
else:
mock_get.side_effect = [
MagicMock(status_code=404), # Branch doesn't exist
MagicMock(json=lambda: {'default_branch': 'main'}), # Get default branch
]

mock_post.return_value.json.return_value = {
'html_url': 'https://github.com/test-owner/test-repo/pull/1'
}
Expand All @@ -360,10 +378,12 @@ def test_send_pull_request(
patch_dir=repo_path,
pr_type=pr_type,
llm_config=mock_llm_config,
target_branch=target_branch,
)

# Assert API calls
assert mock_get.call_count == 2
expected_get_calls = 2
assert mock_get.call_count == expected_get_calls

# Check branch creation and push
assert mock_run.call_count == 2
Expand Down Expand Up @@ -401,10 +421,41 @@ def test_send_pull_request(
assert post_data['title'] == 'Fix issue #42: Test Issue'
assert post_data['body'].startswith('This pull request fixes #42.')
assert post_data['head'] == 'openhands-fix-issue-42'
assert post_data['base'] == 'main'
assert post_data['base'] == (target_branch if target_branch else 'main')
assert post_data['draft'] == (pr_type == 'draft')


@patch('requests.get')
def test_send_pull_request_invalid_target_branch(
mock_get, mock_github_issue, mock_output_dir, mock_llm_config
):
"""Test that an error is raised when specifying a non-existent target branch"""
repo_path = os.path.join(mock_output_dir, 'repo')

# Mock API response for non-existent branch
mock_get.side_effect = [
MagicMock(status_code=404), # Branch doesn't exist
MagicMock(status_code=404), # Target branch doesn't exist
]

# Test that ValueError is raised when target branch doesn't exist
with pytest.raises(
ValueError, match='Target branch nonexistent-branch does not exist'
):
send_pull_request(
github_issue=mock_github_issue,
github_token='test-token',
github_username='test-user',
patch_dir=repo_path,
pr_type='ready',
llm_config=mock_llm_config,
target_branch='nonexistent-branch',
)

# Verify API calls
assert mock_get.call_count == 2


@patch('subprocess.run')
@patch('requests.post')
@patch('requests.get')
Expand Down Expand Up @@ -616,6 +667,7 @@ def test_process_single_pr_update(
mock_llm_config,
None,
False,
None,
)

mock_initialize_repo.assert_called_once_with(mock_output_dir, 1, 'pr', 'branch 1')
Expand Down Expand Up @@ -688,6 +740,7 @@ def test_process_single_issue(
mock_llm_config,
None,
False,
None,
)

# Assert that the mocked functions were called with correct arguments
Expand All @@ -704,9 +757,10 @@ def test_process_single_issue(
github_username=github_username,
patch_dir=f'{mock_output_dir}/patches/issue_1',
pr_type=pr_type,
llm_config=mock_llm_config,
fork_owner=None,
additional_message=resolver_output.success_explanation,
llm_config=mock_llm_config,
target_branch=None,
)


Expand Down Expand Up @@ -757,6 +811,7 @@ def test_process_single_issue_unsuccessful(
mock_llm_config,
None,
False,
None,
)

# Assert that none of the mocked functions were called
Expand Down Expand Up @@ -863,6 +918,7 @@ def test_process_all_successful_issues(
mock_llm_config,
None,
False,
None,
),
call(
'output_dir',
Expand All @@ -873,6 +929,7 @@ def test_process_all_successful_issues(
mock_llm_config,
None,
False,
None,
),
]
)
Expand Down Expand Up @@ -971,6 +1028,7 @@ def test_main(
mock_args.llm_model = 'mock_model'
mock_args.llm_base_url = 'mock_url'
mock_args.llm_api_key = 'mock_key'
mock_args.target_branch = None
mock_parser.return_value.parse_args.return_value = mock_args

# Setup environment variables
Expand All @@ -994,12 +1052,8 @@ def test_main(
api_key=mock_args.llm_api_key,
)

# Assert function calls
mock_parser.assert_called_once()
mock_getenv.assert_any_call('GITHUB_TOKEN')
mock_path_exists.assert_called_with('/mock/output')
mock_load_single_resolver_output.assert_called_with('/mock/output/output.jsonl', 42)
mock_process_single_issue.assert_called_with(
# Use any_call instead of assert_called_with for more flexible matching
assert mock_process_single_issue.call_args == call(
'/mock/output',
mock_resolver_output,
'mock_token',
Expand All @@ -1008,8 +1062,15 @@ def test_main(
llm_config,
None,
False,
mock_args.target_branch,
)

# Other assertions
mock_parser.assert_called_once()
mock_getenv.assert_any_call('GITHUB_TOKEN')
mock_path_exists.assert_called_with('/mock/output')
mock_load_single_resolver_output.assert_called_with('/mock/output/output.jsonl', 42)

# Test for 'all_successful' issue number
mock_args.issue_number = 'all_successful'
main()
Expand Down
Loading