diff --git a/pull-request.py b/pull-request.py index ff569d0..f95ad1d 100755 --- a/pull-request.py +++ b/pull-request.py @@ -234,6 +234,22 @@ def find_pull_request(listing, source): return entry +def find_default_branch(): + """Find default branch for a repo (only called if branch not provided) + """ + response = requests.get(REPO_URL) + + # Case 1: 404 might need a token + if response.status_code == 404: + response = requests.get(REPO_URL, headers=HEADERS) + if response.status_code != 200: + abort_if_fail(response, "Unable to retrieve default branch") + + default_branch = response.json()["default_branch"] + print("Found default branch: %s" % default_branch) + return default_branch + + def add_reviewers(entry, reviewers, team_reviewers): """Given regular or team reviewers, add them to a PR. @@ -343,8 +359,8 @@ def main(): if not branch_prefix: print("No branch prefix is set, all branches will be used.") - # Default to master to support older, will eventually change to main - pull_request_branch = os.environ.get("PULL_REQUEST_BRANCH", "master") + # Default to project default branch if none provided + pull_request_branch = os.environ.get("PULL_REQUEST_BRANCH", find_default_branch()) print("Pull requests will go to %s" % pull_request_branch) # Pull request draft