Skip to content

Commit

Permalink
Allow to merge to a specific target branch instead of main (#5109)
Browse files Browse the repository at this point in the history
  • Loading branch information
ryx2 authored Nov 19, 2024
1 parent ca64c69 commit 2c58038
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 22 deletions.
32 changes: 25 additions & 7 deletions openhands/resolver/send_pull_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def send_pull_request(
pr_type: str,
fork_owner: str | None = None,
additional_message: str | None = None,
target_branch: str | None = None,
) -> str:
if pr_type not in ['branch', 'draft', 'ready']:
raise ValueError(f'Invalid pr_type: {pr_type}')
Expand All @@ -224,12 +225,19 @@ def send_pull_request(
attempt += 1
branch_name = f'{base_branch_name}-try{attempt}'

# Get the default branch
print('Getting default branch...')
response = requests.get(f'{base_url}', headers=headers)
response.raise_for_status()
default_branch = response.json()['default_branch']
print(f'Default branch: {default_branch}')
# Get the default branch or use specified target branch
print('Getting base branch...')
if target_branch:
base_branch = target_branch
# Verify the target branch exists
response = requests.get(f'{base_url}/branches/{target_branch}', headers=headers)
if response.status_code != 200:
raise ValueError(f'Target branch {target_branch} does not exist')
else:
response = requests.get(f'{base_url}', headers=headers)
response.raise_for_status()
base_branch = response.json()['default_branch']
print(f'Base branch: {base_branch}')

# Create and checkout the new branch
print('Creating new branch...')
Expand Down Expand Up @@ -279,7 +287,7 @@ def send_pull_request(
'title': pr_title, # No need to escape title for GitHub API
'body': pr_body,
'head': branch_name,
'base': default_branch,
'base': base_branch,
'draft': pr_type == 'draft',
}
response = requests.post(f'{base_url}/pulls', headers=headers, json=data)
Expand Down Expand Up @@ -435,6 +443,7 @@ def process_single_issue(
llm_config: LLMConfig,
fork_owner: str | None,
send_on_failure: bool,
target_branch: str | None = None,
) -> None:
if not resolver_output.success and not send_on_failure:
print(
Expand Down Expand Up @@ -484,6 +493,7 @@ def process_single_issue(
llm_config=llm_config,
fork_owner=fork_owner,
additional_message=resolver_output.success_explanation,
target_branch=target_branch,
)


Expand All @@ -508,6 +518,7 @@ def process_all_successful_issues(
llm_config,
fork_owner,
False,
None,
)


Expand Down Expand Up @@ -573,6 +584,12 @@ def main():
default=None,
help='Base URL for the LLM model.',
)
parser.add_argument(
'--target-branch',
type=str,
default=None,
help='Target branch to create the pull request against (defaults to repository default branch)',
)
my_args = parser.parse_args()

github_token = (
Expand Down Expand Up @@ -625,6 +642,7 @@ def main():
llm_config,
my_args.fork_owner,
my_args.send_on_failure,
my_args.target_branch,
)


Expand Down
91 changes: 76 additions & 15 deletions tests/unit/resolver/test_send_pull_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,17 @@ def test_update_existing_pull_request(
)


@pytest.mark.parametrize('pr_type', ['branch', 'draft', 'ready'])
@pytest.mark.parametrize(
'pr_type,target_branch',
[
('branch', None),
('draft', None),
('ready', None),
('branch', 'feature'),
('draft', 'develop'),
('ready', 'staging'),
],
)
@patch('subprocess.run')
@patch('requests.post')
@patch('requests.get')
Expand All @@ -334,14 +344,22 @@ def test_send_pull_request(
mock_output_dir,
mock_llm_config,
pr_type,
target_branch,
):
repo_path = os.path.join(mock_output_dir, 'repo')

# Mock API responses
mock_get.side_effect = [
MagicMock(status_code=404), # Branch doesn't exist
MagicMock(json=lambda: {'default_branch': 'main'}),
]
# Mock API responses based on whether target_branch is specified
if target_branch:
mock_get.side_effect = [
MagicMock(status_code=404), # Branch doesn't exist
MagicMock(status_code=200), # Target branch exists
]
else:
mock_get.side_effect = [
MagicMock(status_code=404), # Branch doesn't exist
MagicMock(json=lambda: {'default_branch': 'main'}), # Get default branch
]

mock_post.return_value.json.return_value = {
'html_url': 'https://github.com/test-owner/test-repo/pull/1'
}
Expand All @@ -360,10 +378,12 @@ def test_send_pull_request(
patch_dir=repo_path,
pr_type=pr_type,
llm_config=mock_llm_config,
target_branch=target_branch,
)

# Assert API calls
assert mock_get.call_count == 2
expected_get_calls = 2
assert mock_get.call_count == expected_get_calls

# Check branch creation and push
assert mock_run.call_count == 2
Expand Down Expand Up @@ -401,10 +421,41 @@ def test_send_pull_request(
assert post_data['title'] == 'Fix issue #42: Test Issue'
assert post_data['body'].startswith('This pull request fixes #42.')
assert post_data['head'] == 'openhands-fix-issue-42'
assert post_data['base'] == 'main'
assert post_data['base'] == (target_branch if target_branch else 'main')
assert post_data['draft'] == (pr_type == 'draft')


@patch('requests.get')
def test_send_pull_request_invalid_target_branch(
mock_get, mock_github_issue, mock_output_dir, mock_llm_config
):
"""Test that an error is raised when specifying a non-existent target branch"""
repo_path = os.path.join(mock_output_dir, 'repo')

# Mock API response for non-existent branch
mock_get.side_effect = [
MagicMock(status_code=404), # Branch doesn't exist
MagicMock(status_code=404), # Target branch doesn't exist
]

# Test that ValueError is raised when target branch doesn't exist
with pytest.raises(
ValueError, match='Target branch nonexistent-branch does not exist'
):
send_pull_request(
github_issue=mock_github_issue,
github_token='test-token',
github_username='test-user',
patch_dir=repo_path,
pr_type='ready',
llm_config=mock_llm_config,
target_branch='nonexistent-branch',
)

# Verify API calls
assert mock_get.call_count == 2


@patch('subprocess.run')
@patch('requests.post')
@patch('requests.get')
Expand Down Expand Up @@ -616,6 +667,7 @@ def test_process_single_pr_update(
mock_llm_config,
None,
False,
None,
)

mock_initialize_repo.assert_called_once_with(mock_output_dir, 1, 'pr', 'branch 1')
Expand Down Expand Up @@ -688,6 +740,7 @@ def test_process_single_issue(
mock_llm_config,
None,
False,
None,
)

# Assert that the mocked functions were called with correct arguments
Expand All @@ -704,9 +757,10 @@ def test_process_single_issue(
github_username=github_username,
patch_dir=f'{mock_output_dir}/patches/issue_1',
pr_type=pr_type,
llm_config=mock_llm_config,
fork_owner=None,
additional_message=resolver_output.success_explanation,
llm_config=mock_llm_config,
target_branch=None,
)


Expand Down Expand Up @@ -757,6 +811,7 @@ def test_process_single_issue_unsuccessful(
mock_llm_config,
None,
False,
None,
)

# Assert that none of the mocked functions were called
Expand Down Expand Up @@ -863,6 +918,7 @@ def test_process_all_successful_issues(
mock_llm_config,
None,
False,
None,
),
call(
'output_dir',
Expand All @@ -873,6 +929,7 @@ def test_process_all_successful_issues(
mock_llm_config,
None,
False,
None,
),
]
)
Expand Down Expand Up @@ -971,6 +1028,7 @@ def test_main(
mock_args.llm_model = 'mock_model'
mock_args.llm_base_url = 'mock_url'
mock_args.llm_api_key = 'mock_key'
mock_args.target_branch = None
mock_parser.return_value.parse_args.return_value = mock_args

# Setup environment variables
Expand All @@ -994,12 +1052,8 @@ def test_main(
api_key=mock_args.llm_api_key,
)

# Assert function calls
mock_parser.assert_called_once()
mock_getenv.assert_any_call('GITHUB_TOKEN')
mock_path_exists.assert_called_with('/mock/output')
mock_load_single_resolver_output.assert_called_with('/mock/output/output.jsonl', 42)
mock_process_single_issue.assert_called_with(
# Use any_call instead of assert_called_with for more flexible matching
assert mock_process_single_issue.call_args == call(
'/mock/output',
mock_resolver_output,
'mock_token',
Expand All @@ -1008,8 +1062,15 @@ def test_main(
llm_config,
None,
False,
mock_args.target_branch,
)

# Other assertions
mock_parser.assert_called_once()
mock_getenv.assert_any_call('GITHUB_TOKEN')
mock_path_exists.assert_called_with('/mock/output')
mock_load_single_resolver_output.assert_called_with('/mock/output/output.jsonl', 42)

# Test for 'all_successful' issue number
mock_args.issue_number = 'all_successful'
main()
Expand Down

0 comments on commit 2c58038

Please sign in to comment.