|
| 1 | +"""RestructuredText changelog generator.""" |
| 2 | + |
| 3 | +from collections import defaultdict |
| 4 | +import os |
| 5 | + |
| 6 | +HEADERS = { |
| 7 | + "Accept": "application/vnd.github.v3+json", |
| 8 | +} |
| 9 | + |
| 10 | +if os.getenv("GITHUB_TOKEN") is not None: |
| 11 | + HEADERS["Authorization"] = f"token {os.getenv('GITHUB_TOKEN')}" |
| 12 | + |
| 13 | +OWNER = "jdb78" |
| 14 | +REPO = "pytorch-forecasting" |
| 15 | +GITHUB_REPOS = "https://api.github.com/repos" |
| 16 | + |
| 17 | + |
| 18 | +def fetch_merged_pull_requests(page: int = 1) -> list[dict]: |
| 19 | + """Fetch a page of merged pull requests. |
| 20 | +
|
| 21 | + Parameters |
| 22 | + ---------- |
| 23 | + page : int, optional |
| 24 | + Page number to fetch, by default 1. |
| 25 | + Returns all merged pull request from the ``page``-th page of closed PRs, |
| 26 | + where pages are in descending order of last update. |
| 27 | +
|
| 28 | + Returns |
| 29 | + ------- |
| 30 | + list |
| 31 | + List of merged pull requests from the ``page``-th page of closed PRs. |
| 32 | + Elements of list are dictionaries with PR details, as obtained |
| 33 | + from the GitHub API via ``httpx.get``, from the ``pulls`` endpoint. |
| 34 | + """ |
| 35 | + import httpx |
| 36 | + |
| 37 | + params = { |
| 38 | + "base": "main", |
| 39 | + "state": "closed", |
| 40 | + "page": page, |
| 41 | + "per_page": 50, |
| 42 | + "sort": "updated", |
| 43 | + "direction": "desc", |
| 44 | + } |
| 45 | + r = httpx.get( |
| 46 | + f"{GITHUB_REPOS}/{OWNER}/{REPO}/pulls", |
| 47 | + headers=HEADERS, |
| 48 | + params=params, |
| 49 | + ) |
| 50 | + return [pr for pr in r.json() if pr["merged_at"]] |
| 51 | + |
| 52 | + |
| 53 | +def fetch_latest_release(): # noqa: D103 |
| 54 | + """Fetch the latest release from the GitHub API. |
| 55 | +
|
| 56 | + Returns |
| 57 | + ------- |
| 58 | + dict |
| 59 | + Dictionary with details of the latest release. |
| 60 | + Dictionary is as obtained from the GitHub API via ``httpx.get``, |
| 61 | + for ``releases/latest`` endpoint. |
| 62 | + """ |
| 63 | + import httpx |
| 64 | + |
| 65 | + response = httpx.get(f"{GITHUB_REPOS}/{OWNER}/{REPO}/releases/latest", headers=HEADERS) |
| 66 | + |
| 67 | + if response.status_code == 200: |
| 68 | + return response.json() |
| 69 | + else: |
| 70 | + raise ValueError(response.text, response.status_code) |
| 71 | + |
| 72 | + |
| 73 | +def fetch_pull_requests_since_last_release() -> list[dict]: |
| 74 | + """Fetch all pull requests merged since last release. |
| 75 | +
|
| 76 | + Returns |
| 77 | + ------- |
| 78 | + list |
| 79 | + List of pull requests merged since the latest release. |
| 80 | + Elements of list are dictionaries with PR details, as obtained |
| 81 | + from the GitHub API via ``httpx.get``, through ``fetch_merged_pull_requests``. |
| 82 | + """ |
| 83 | + from dateutil import parser |
| 84 | + |
| 85 | + release = fetch_latest_release() |
| 86 | + published_at = parser.parse(release["published_at"]) |
| 87 | + print(f"Latest release {release['tag_name']} was published at {published_at}") |
| 88 | + |
| 89 | + is_exhausted = False |
| 90 | + page = 1 |
| 91 | + all_pulls = [] |
| 92 | + while not is_exhausted: |
| 93 | + pulls = fetch_merged_pull_requests(page=page) |
| 94 | + all_pulls.extend([p for p in pulls if parser.parse(p["merged_at"]) > published_at]) |
| 95 | + is_exhausted = any(parser.parse(p["updated_at"]) < published_at for p in pulls) |
| 96 | + page += 1 |
| 97 | + return all_pulls |
| 98 | + |
| 99 | + |
| 100 | +def github_compare_tags(tag_left: str, tag_right: str = "HEAD"): |
| 101 | + """Compare commit between two tags.""" |
| 102 | + import httpx |
| 103 | + |
| 104 | + response = httpx.get(f"{GITHUB_REPOS}/{OWNER}/{REPO}/compare/{tag_left}...{tag_right}") |
| 105 | + if response.status_code == 200: |
| 106 | + return response.json() |
| 107 | + else: |
| 108 | + raise ValueError(response.text, response.status_code) |
| 109 | + |
| 110 | + |
| 111 | +def render_contributors(prs: list, fmt: str = "rst"): |
| 112 | + """Find unique authors and print a list in given format.""" |
| 113 | + authors = sorted({pr["user"]["login"] for pr in prs}, key=lambda x: x.lower()) |
| 114 | + |
| 115 | + header = "Contributors" |
| 116 | + if fmt == "github": |
| 117 | + print(f"### {header}") |
| 118 | + print(", ".join(f"@{user}" for user in authors)) |
| 119 | + elif fmt == "rst": |
| 120 | + print(header) |
| 121 | + print("~" * len(header), end="\n\n") |
| 122 | + print(",\n".join(f":user:`{user}`" for user in authors)) |
| 123 | + |
| 124 | + |
| 125 | +def assign_prs(prs, categs: list[dict[str, list[str]]]): |
| 126 | + """Assign PR to categories based on labels.""" |
| 127 | + assigned = defaultdict(list) |
| 128 | + |
| 129 | + for i, pr in enumerate(prs): |
| 130 | + for cat in categs: |
| 131 | + pr_labels = [label["name"] for label in pr["labels"]] |
| 132 | + if not set(cat["labels"]).isdisjoint(set(pr_labels)): |
| 133 | + assigned[cat["title"]].append(i) |
| 134 | + |
| 135 | + # if any(l.startswith("module") for l in pr_labels): |
| 136 | + # print(i, pr_labels) |
| 137 | + |
| 138 | + assigned["Other"] = list(set(range(len(prs))) - {i for _, j in assigned.items() for i in j}) |
| 139 | + |
| 140 | + return assigned |
| 141 | + |
| 142 | + |
| 143 | +def render_row(pr): |
| 144 | + """Render a single row with PR in restructuredText format.""" |
| 145 | + print( |
| 146 | + "*", |
| 147 | + pr["title"].replace("`", "``"), |
| 148 | + f"(:pr:`{pr['number']}`)", |
| 149 | + f":user:`{pr['user']['login']}`", |
| 150 | + ) |
| 151 | + |
| 152 | + |
| 153 | +def render_changelog(prs, assigned): |
| 154 | + # sourcery skip: use-named-expression |
| 155 | + """Render changelog.""" |
| 156 | + from dateutil import parser |
| 157 | + |
| 158 | + for title, _ in assigned.items(): |
| 159 | + pr_group = [prs[i] for i in assigned[title]] |
| 160 | + if pr_group: |
| 161 | + print(f"\n{title}") |
| 162 | + print("~" * len(title), end="\n\n") |
| 163 | + |
| 164 | + for pr in sorted(pr_group, key=lambda x: parser.parse(x["merged_at"])): |
| 165 | + render_row(pr) |
| 166 | + |
| 167 | + |
| 168 | +if __name__ == "__main__": |
| 169 | + categories = [ |
| 170 | + {"title": "Enhancements", "labels": ["feature", "enhancement"]}, |
| 171 | + {"title": "Fixes", "labels": ["bug", "fix", "bugfix"]}, |
| 172 | + {"title": "Maintenance", "labels": ["maintenance", "chore"]}, |
| 173 | + {"title": "Refactored", "labels": ["refactor"]}, |
| 174 | + {"title": "Documentation", "labels": ["documentation"]}, |
| 175 | + ] |
| 176 | + |
| 177 | + pulls = fetch_pull_requests_since_last_release() |
| 178 | + print(f"Found {len(pulls)} merged PRs since last release") |
| 179 | + assigned = assign_prs(pulls, categories) |
| 180 | + render_changelog(pulls, assigned) |
| 181 | + print() |
| 182 | + render_contributors(pulls) |
| 183 | + |
| 184 | + release = fetch_latest_release() |
| 185 | + diff = github_compare_tags(release["tag_name"]) |
| 186 | + if diff["total_commits"] != len(pulls): |
| 187 | + raise ValueError( |
| 188 | + "Something went wrong and not all PR were fetched. " |
| 189 | + f'There are {len(pulls)} PRs but {diff["total_commits"]} in the diff. ' |
| 190 | + "Please verify that all PRs are included in the changelog." |
| 191 | + ) |
0 commit comments