Only cache branch protection within the scope of the function (#4747)
* Only cache branch protection within the scope of the function

* Move branch protection cache one level up

* Add comment explanation

Co-authored-by: sarayourfriend <[email protected]>

---------

Co-authored-by: sarayourfriend <[email protected]>
AetherUnbound and sarayourfriend authored Aug 19, 2024
1 parent 1f28543 commit ec1dc6a
Showing 1 changed file with 27 additions and 14 deletions.
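
The diff below replaces a module-level branch-protection cache with a cache that each DAG run builds and passes explicitly into the functions that need it. The following is a rough, self-contained sketch of that pattern, not code from the repository; FakeAPI and fetch_rules are illustrative stand-ins.

# Minimal sketch of the caching pattern this commit adopts: the caller creates
# the cache, passes it in, and discards it when the run ends.
from collections import defaultdict


class FakeAPI:
    """Stand-in for GitHubAPI; counts how many real requests are made."""

    def __init__(self):
        self.calls = 0

    def get_branch_protection(self, repo: str, branch: str) -> dict:
        self.calls += 1
        return {
            "required_pull_request_reviews": {"required_approving_review_count": 2}
        }


def fetch_rules(api: FakeAPI, repo: str, branch: str, cache: dict) -> dict:
    # Same shape as the new get_branch_protection: the cache is an argument,
    # so its lifetime is controlled by whoever creates it.
    if branch not in cache[repo]:
        cache[repo][branch] = api.get_branch_protection(repo, branch)
    return cache[repo][branch]


def run_once(api: FakeAPI) -> None:
    cache = defaultdict(dict)  # fresh cache per run, discarded afterwards
    fetch_rules(api, "openverse", "main", cache)
    fetch_rules(api, "openverse", "main", cache)  # served from the run's cache


api = FakeAPI()
run_once(api)
run_once(api)
print(api.calls)  # 2 -> one real request per run, not one per process lifetime

Because the cache lives only as long as one run, a change to a repository's branch protection settings is picked up by the next run at the latest, and every PR evaluated within a single run sees the same settings.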
41 changes: 27 additions & 14 deletions catalog/dags/maintenance/pr_review_reminders/pr_review_reminders.py
@@ -173,38 +173,43 @@ def base_repo_name(pr: dict):
     return pr["base"]["repo"]["name"]
 
 
-_BRANCH_PROTECTION_CACHE = defaultdict(dict)
-
-
-def get_branch_protection(gh: GitHubAPI, repo: str, branch_name: str) -> dict:
-    if branch_name not in _BRANCH_PROTECTION_CACHE[repo]:
-        _BRANCH_PROTECTION_CACHE[repo][branch_name] = gh.get_branch_protection(
-            repo, branch_name
-        )
+def get_branch_protection(
+    gh: GitHubAPI, repo: str, branch_name: str, cache: dict
+) -> dict:
+    if branch_name not in cache[repo]:
+        cache[repo][branch_name] = gh.get_branch_protection(repo, branch_name)
 
-    return _BRANCH_PROTECTION_CACHE[repo][branch_name]
+    return cache[repo][branch_name]
 
 
-def get_min_required_approvals(gh: GitHubAPI, pr: dict) -> int:
+def get_min_required_approvals(
+    gh: GitHubAPI, pr: dict, branch_protection_cache: dict
+) -> int:
     repo = base_repo_name(pr)
     branch_name = pr["base"]["ref"]
 
     try:
-        branch_protection_rules = get_branch_protection(gh, repo, branch_name)
+        branch_protection_rules = get_branch_protection(
+            gh, repo, branch_name, branch_protection_cache
+        )
     except HTTPError as e:
         # If the base branch does not have protection rules, the request
         # above will 404. In that case, fall back to the rules for `main`
         # as a safe default.
         if e.response is not None and e.response.status_code == 404:
-            branch_protection_rules = get_branch_protection(gh, repo, "main")
+            branch_protection_rules = get_branch_protection(
+                gh, repo, "main", branch_protection_cache
+            )
         else:
             raise e
 
     if "required_pull_request_reviews" not in branch_protection_rules:
         # This can happen in the rare case where a PR is multiple branches deep,
         # e.g. it depends on a branch which depends on a branch which depends on main.
         # In that case, default to the rules for `main` as a safe default.
-        branch_protection_rules = get_branch_protection(gh, repo, "main")
+        branch_protection_rules = get_branch_protection(
+            gh, repo, "main", branch_protection_cache
+        )
 
     return branch_protection_rules["required_pull_request_reviews"][
         "required_approving_review_count"
@@ -214,6 +219,12 @@ def get_min_required_approvals(gh: GitHubAPI, pr: dict) -> int:
 @task(task_id="pr_review_reminder_operator")
 def post_reminders(maintainers: set[str], github_pat: str, dry_run: bool):
     gh = GitHubAPI(github_pat)
+    # Build a new cache for each DAG run so that changes to repository branch settings
+    # are reflected "by the next DAG run". Caching them at the run level also ensures
+    # that all evaluations for pings happen with the same settings, preventing changes
+    # during a run from causing PRs to receive different treatment, which could
+    # produce confusing results we might interpret as a bug rather than just a delay.
+    branch_protection_cache = defaultdict(dict)
 
     open_prs = []
     for repo in REPOSITORIES:
@@ -267,7 +278,9 @@ def post_reminders(maintainers: set[str], github_pat: str, dry_run: bool):
         existing_reviews = gh.get_pull_reviews(base_repo_name(pr), pr["number"])
 
         approved_reviews = [r for r in existing_reviews if r["state"] == "APPROVED"]
-        if len(approved_reviews) >= get_min_required_approvals(gh, pr):
+        if len(approved_reviews) >= get_min_required_approvals(
+            gh, pr, branch_protection_cache
+        ):
             # if PR already has sufficient reviews to be merged, do not ping
             # the requested reviewers.
             continue
