From 2abdfbf4be1c6bd86588c01b3983fa53468c79ae Mon Sep 17 00:00:00 2001 From: Raymond Xu Date: Mon, 18 Nov 2024 11:44:32 -0800 Subject: [PATCH 1/2] allow to merge ot a specific target branch instead of main --- openhands/resolver/send_pull_request.py | 32 +++++++++++++++++++------ 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/openhands/resolver/send_pull_request.py b/openhands/resolver/send_pull_request.py index eade7fcfc419..8a9d6118bedd 100644 --- a/openhands/resolver/send_pull_request.py +++ b/openhands/resolver/send_pull_request.py @@ -203,6 +203,7 @@ def send_pull_request( pr_type: str, fork_owner: str | None = None, additional_message: str | None = None, + target_branch: str | None = None, ) -> str: if pr_type not in ['branch', 'draft', 'ready']: raise ValueError(f'Invalid pr_type: {pr_type}') @@ -224,12 +225,19 @@ def send_pull_request( attempt += 1 branch_name = f'{base_branch_name}-try{attempt}' - # Get the default branch - print('Getting default branch...') - response = requests.get(f'{base_url}', headers=headers) - response.raise_for_status() - default_branch = response.json()['default_branch'] - print(f'Default branch: {default_branch}') + # Get the default branch or use specified target branch + print('Getting base branch...') + if target_branch: + base_branch = target_branch + # Verify the target branch exists + response = requests.get(f'{base_url}/branches/{target_branch}', headers=headers) + if response.status_code != 200: + raise ValueError(f'Target branch {target_branch} does not exist') + else: + response = requests.get(f'{base_url}', headers=headers) + response.raise_for_status() + base_branch = response.json()['default_branch'] + print(f'Base branch: {base_branch}') # Create and checkout the new branch print('Creating new branch...') @@ -279,7 +287,7 @@ def send_pull_request( 'title': pr_title, # No need to escape title for GitHub API 'body': pr_body, 'head': branch_name, - 'base': default_branch, + 'base': base_branch, 'draft': pr_type == 'draft', } response = requests.post(f'{base_url}/pulls', headers=headers, json=data) @@ -435,6 +443,7 @@ def process_single_issue( llm_config: LLMConfig, fork_owner: str | None, send_on_failure: bool, + target_branch: str | None = None, ) -> None: if not resolver_output.success and not send_on_failure: print( @@ -484,6 +493,7 @@ def process_single_issue( llm_config=llm_config, fork_owner=fork_owner, additional_message=resolver_output.success_explanation, + target_branch=target_branch, ) @@ -508,6 +518,7 @@ def process_all_successful_issues( llm_config, fork_owner, False, + None, ) @@ -573,6 +584,12 @@ def main(): default=None, help='Base URL for the LLM model.', ) + parser.add_argument( + '--target-branch', + type=str, + default=None, + help='Target branch to create the pull request against (defaults to repository default branch)', + ) my_args = parser.parse_args() github_token = ( @@ -625,6 +642,7 @@ def main(): llm_config, my_args.fork_owner, my_args.send_on_failure, + my_args.target_branch, ) From 87a978356cb4a330b39281d4aeb14f98e11635de Mon Sep 17 00:00:00 2001 From: Raymond Xu Date: Mon, 18 Nov 2024 23:33:09 -0800 Subject: [PATCH 2/2] fix test_send_pull_request.py to adopt the target branch arg --- tests/unit/resolver/test_send_pull_request.py | 91 ++++++++++++++++--- 1 file changed, 76 insertions(+), 15 deletions(-) diff --git a/tests/unit/resolver/test_send_pull_request.py b/tests/unit/resolver/test_send_pull_request.py index 951be1af006c..f83e2e97ec2f 100644 --- a/tests/unit/resolver/test_send_pull_request.py +++ b/tests/unit/resolver/test_send_pull_request.py @@ -322,7 +322,17 @@ def test_update_existing_pull_request( ) -@pytest.mark.parametrize('pr_type', ['branch', 'draft', 'ready']) +@pytest.mark.parametrize( + 'pr_type,target_branch', + [ + ('branch', None), + ('draft', None), + ('ready', None), + ('branch', 'feature'), + ('draft', 'develop'), + ('ready', 'staging'), + ], +) @patch('subprocess.run') @patch('requests.post') @patch('requests.get') @@ -334,14 +344,22 @@ def test_send_pull_request( mock_output_dir, mock_llm_config, pr_type, + target_branch, ): repo_path = os.path.join(mock_output_dir, 'repo') - # Mock API responses - mock_get.side_effect = [ - MagicMock(status_code=404), # Branch doesn't exist - MagicMock(json=lambda: {'default_branch': 'main'}), - ] + # Mock API responses based on whether target_branch is specified + if target_branch: + mock_get.side_effect = [ + MagicMock(status_code=404), # Branch doesn't exist + MagicMock(status_code=200), # Target branch exists + ] + else: + mock_get.side_effect = [ + MagicMock(status_code=404), # Branch doesn't exist + MagicMock(json=lambda: {'default_branch': 'main'}), # Get default branch + ] + mock_post.return_value.json.return_value = { 'html_url': 'https://github.com/test-owner/test-repo/pull/1' } @@ -360,10 +378,12 @@ def test_send_pull_request( patch_dir=repo_path, pr_type=pr_type, llm_config=mock_llm_config, + target_branch=target_branch, ) # Assert API calls - assert mock_get.call_count == 2 + expected_get_calls = 2 + assert mock_get.call_count == expected_get_calls # Check branch creation and push assert mock_run.call_count == 2 @@ -401,10 +421,41 @@ def test_send_pull_request( assert post_data['title'] == 'Fix issue #42: Test Issue' assert post_data['body'].startswith('This pull request fixes #42.') assert post_data['head'] == 'openhands-fix-issue-42' - assert post_data['base'] == 'main' + assert post_data['base'] == (target_branch if target_branch else 'main') assert post_data['draft'] == (pr_type == 'draft') +@patch('requests.get') +def test_send_pull_request_invalid_target_branch( + mock_get, mock_github_issue, mock_output_dir, mock_llm_config +): + """Test that an error is raised when specifying a non-existent target branch""" + repo_path = os.path.join(mock_output_dir, 'repo') + + # Mock API response for non-existent branch + mock_get.side_effect = [ + MagicMock(status_code=404), # Branch doesn't exist + MagicMock(status_code=404), # Target branch doesn't exist + ] + + # Test that ValueError is raised when target branch doesn't exist + with pytest.raises( + ValueError, match='Target branch nonexistent-branch does not exist' + ): + send_pull_request( + github_issue=mock_github_issue, + github_token='test-token', + github_username='test-user', + patch_dir=repo_path, + pr_type='ready', + llm_config=mock_llm_config, + target_branch='nonexistent-branch', + ) + + # Verify API calls + assert mock_get.call_count == 2 + + @patch('subprocess.run') @patch('requests.post') @patch('requests.get') @@ -616,6 +667,7 @@ def test_process_single_pr_update( mock_llm_config, None, False, + None, ) mock_initialize_repo.assert_called_once_with(mock_output_dir, 1, 'pr', 'branch 1') @@ -688,6 +740,7 @@ def test_process_single_issue( mock_llm_config, None, False, + None, ) # Assert that the mocked functions were called with correct arguments @@ -704,9 +757,10 @@ def test_process_single_issue( github_username=github_username, patch_dir=f'{mock_output_dir}/patches/issue_1', pr_type=pr_type, + llm_config=mock_llm_config, fork_owner=None, additional_message=resolver_output.success_explanation, - llm_config=mock_llm_config, + target_branch=None, ) @@ -757,6 +811,7 @@ def test_process_single_issue_unsuccessful( mock_llm_config, None, False, + None, ) # Assert that none of the mocked functions were called @@ -863,6 +918,7 @@ def test_process_all_successful_issues( mock_llm_config, None, False, + None, ), call( 'output_dir', @@ -873,6 +929,7 @@ def test_process_all_successful_issues( mock_llm_config, None, False, + None, ), ] ) @@ -971,6 +1028,7 @@ def test_main( mock_args.llm_model = 'mock_model' mock_args.llm_base_url = 'mock_url' mock_args.llm_api_key = 'mock_key' + mock_args.target_branch = None mock_parser.return_value.parse_args.return_value = mock_args # Setup environment variables @@ -994,12 +1052,8 @@ def test_main( api_key=mock_args.llm_api_key, ) - # Assert function calls - mock_parser.assert_called_once() - mock_getenv.assert_any_call('GITHUB_TOKEN') - mock_path_exists.assert_called_with('/mock/output') - mock_load_single_resolver_output.assert_called_with('/mock/output/output.jsonl', 42) - mock_process_single_issue.assert_called_with( + # Use any_call instead of assert_called_with for more flexible matching + assert mock_process_single_issue.call_args == call( '/mock/output', mock_resolver_output, 'mock_token', @@ -1008,8 +1062,15 @@ def test_main( llm_config, None, False, + mock_args.target_branch, ) + # Other assertions + mock_parser.assert_called_once() + mock_getenv.assert_any_call('GITHUB_TOKEN') + mock_path_exists.assert_called_with('/mock/output') + mock_load_single_resolver_output.assert_called_with('/mock/output/output.jsonl', 42) + # Test for 'all_successful' issue number mock_args.issue_number = 'all_successful' main()