diff --git a/.restyled.yaml b/.restyled.yaml new file mode 100644 index 0000000..9ccfcfb --- /dev/null +++ b/.restyled.yaml @@ -0,0 +1,5 @@ +--- +restylers: + - pyment: + enabled: false + - "*" diff --git a/tools/hub-restyled b/tools/hub-restyled new file mode 100755 index 0000000..ebf88af --- /dev/null +++ b/tools/hub-restyled @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +import http.client +import io +import json +import os +import subprocess +import urllib.parse +import urllib.request +import zipfile +from typing import Optional + +token = os.environ["GITHUB_TOKEN"] + + +def get_remotes() -> dict[str, str]: + """Get the remotes of the current git repository.""" + lines = (subprocess.check_output([ + "git", + "remote", + "-v", + ]).decode("utf-8").split("\n")) + + remotes = {} + for line in lines: + if not line: + continue + remote, url = line.split("\t", 1) + remotes[remote] = url.split(" ")[0] + + return remotes + + +def get_upstream() -> Optional[str]: + """Get the upstream remote URL.""" + remotes = get_remotes() + if "upstream" in remotes: + return remotes["upstream"] + return None + + +def get_origin() -> Optional[str]: + """Get the origin remote URL.""" + remotes = get_remotes() + if "origin" in remotes: + return remotes["origin"] + return None + + +def get_slug() -> Optional[str]: + """Get the GitHub slug of the current repository.""" + upstream = get_upstream() + if not upstream: + return None + if "git@github.com:" in upstream: + return upstream.split(":")[-1].split(".git")[0] + if "https" in upstream: + return upstream.split("github.com/")[-1].split(".git")[0] + return None + + +def get_github_user() -> Optional[str]: + """Get the GitHub user of the origin repository.""" + origin = get_origin() + if not origin: + return None + return origin.split("/")[0] + + +def get_current_branch() -> str: + """Get the current git branch name.""" + return (subprocess.check_output([ + "git", + "rev-parse", + "--abbrev-ref", + "HEAD", + ]).decode("utf-8").strip()) + + +def get_head_sha() -> str: + """Get the HEAD SHA of the current git branch.""" + return (subprocess.check_output([ + "git", + "rev-parse", + "HEAD", + ]).decode("utf-8").strip()) + + +def get(url: str) -> dict: + """GET and parse JSON from a URL.""" + req = urllib.request.Request( + url, + headers={ + "Accept": "application/vnd.github+json", + "Authorization": f"Bearer {token}", + "X-GitHub-Api-Version": "2022-11-28", + }, + ) + with urllib.request.urlopen(req) as f: + return json.loads(f.read().decode("utf-8")) + + +def download_redirect(url: Optional[str]) -> Optional[bytes]: + """Recursively follow redirects until we get the final URL.""" + if not url: + print("Error: no download URL") + return None + host = urllib.parse.urlparse(url).netloc + h = http.client.HTTPSConnection(host) + h.request( + "GET", + url, + headers={ + "User-Agent": "Mozilla/5.0", + }, + ) + resp = h.getresponse() + if resp.status == 301 or resp.status == 302: + return download_redirect(resp.getheader("Location")) + if resp.status != 200: + print(f"Error downloading: {resp.status}") + return None + return resp.read() + + +def download(url: str) -> Optional[bytes]: + """Download a file from a URL.""" + host = urllib.parse.urlparse(url).netloc + h = http.client.HTTPSConnection(host) + h.request( + "GET", + url, + headers={ + "User-Agent": "Mozilla/5.0", + "Accept": "application/vnd.github.v3+json", + "Authorization": f"Bearer {token}", + "X-GitHub-Api-Version": "2022-11-28", + }, + ) + resp = h.getresponse() + if resp.status != 302: + print( + f'Error fetching download URL: {resp.status} {resp.reason} {resp.read().decode("utf-8")}' + ) + return None + return download_redirect(resp.getheader("Location")) + + +def main() -> None: + """Main entry point.""" + data = get(f"https://api.github.com/repos/{get_slug()}/actions/runs" + f"?branch={get_current_branch()}" + f"&event=pull_request" + f"&head_sha={get_head_sha()}") + # pretty-print + for run in data["workflow_runs"]: + artifacts = get(run["artifacts_url"]) + if artifacts["total_count"] == 1: + zip = download(artifacts["artifacts"][0]["archive_download_url"]) + if not zip: + return + zip_buffer = io.BytesIO(zip) + with zipfile.ZipFile(zip_buffer) as z: + # run patch -p1 < restyled.diff + diff = z.read("restyled.diff") + subprocess.run(["patch", "-p1"], input=diff) + return + print("No restyled.diff found") + + +if __name__ == "__main__": + main()