diff --git a/src/herald/github.py b/src/herald/github.py index 3ccfc7f..0255426 100644 --- a/src/herald/github.py +++ b/src/herald/github.py @@ -102,7 +102,9 @@ def __init__(self) -> None: ) @artifact_download_time.time() - def _download_artifact(self, token: str, repo: str, artifact_id: int) -> bytes: + def _download_artifact( + self, token: str, repo: str, artifact_id: int, fh: IO[bytes] + ): logger.info("Downloading artifact %d from GitHub", artifact_id) github_api_call_count.labels(type="artifact_info").inc() r = requests.get( @@ -121,27 +123,28 @@ def _download_artifact(self, token: str, repo: str, artifact_id: int) -> bytes: raise ArtifactTooLarge(artifact_id, data["size_in_bytes"], repo=repo) github_api_call_count.labels(type="artifact_download").inc() - r = requests.get( + with requests.get( f"https://api.github.com/repos/{repo}/actions/artifacts/{artifact_id}/zip", headers={"Authorization": f"Bearer {token}"}, - ) - - if r.status_code == 410: - logger.info("Artifact %d has expired", artifact_id) - raise ArtifactExpired(artifact_id, repo=repo) - if r.status_code == 404: - logger.info("Artifact %d has does not exist", artifact_id) - raise ArtifactNotFound(artifact_id, repo=repo) - try: - r.raise_for_status() - except Exception as e: - logger.info( - "Got HTTP error for downloading artifact %d", artifact_id, exc_info=True - ) - raise e + ) as r: + if r.status_code == 410: + logger.info("Artifact %d has expired", artifact_id) + raise ArtifactExpired(artifact_id, repo=repo) + if r.status_code == 404: + logger.info("Artifact %d has does not exist", artifact_id) + raise ArtifactNotFound(artifact_id, repo=repo) + try: + r.raise_for_status() + except Exception as e: + logger.info( + "Got HTTP error for downloading artifact %d", + artifact_id, + exc_info=True, + ) + raise e + for chunk in r.iter_content(chunk_size=8192): + fh.write(chunk) logger.info("Download of artifact %d complete", artifact_id) - # @TODO: stream to temporary file - return r.content @contextlib.contextmanager def get_artifact(self, token: str, repo: str, artifact_id: int): @@ -171,19 +174,23 @@ def get_artifact(self, token: str, repo: str, artifact_id: int): logger.info("Culling artifact cache") self._artifact_cache.cull() logger.info("Cull complete") - buffer = self._download_artifact(token, repo, artifact_id) - logger.info( - "Have buffer of size %d for artifact %d, writing to key %s", - len(buffer), - artifact_id, - key, - ) # buffer is zip, let's make a tarball out of it - with tempfile.TemporaryDirectory() as tmpd: - z = zipfile.ZipFile(io.BytesIO(buffer)) - z.extractall(tmpd) + with tempfile.NamedTemporaryFile("wb+") as tar_fh: + with ( + tempfile.TemporaryFile("wb+") as zip_fh, + tempfile.TemporaryDirectory() as tmpd, + ): + self._download_artifact(token, repo, artifact_id, zip_fh) + zip_fh.seek(0) + logger.info( + "Download of artifact %d complete, writing to key %s", + artifact_id, + key, + ) + + z = zipfile.ZipFile(zip_fh) + z.extractall(tmpd) - with tempfile.NamedTemporaryFile("wb+") as tar_fh: t = tarfile.TarFile(fileobj=tar_fh, mode="w") t.add(tmpd, arcname=".", recursive=True) t.close() @@ -191,21 +198,21 @@ def get_artifact(self, token: str, repo: str, artifact_id: int): tar_fh.flush() tar_fh.seek(0) - compressor = zstandard.ZstdCompressor() - with self._artifact_cache.open(key, "wb") as fh: - compressor.copy_stream(tar_fh, fh) + compressor = zstandard.ZstdCompressor() + with self._artifact_cache.open(key, "wb") as fh: + compressor.copy_stream(tar_fh, fh) - logger.info( - "Cache reports key %s created for artifact %d", + logger.info( + "Cache reports key %s created for artifact %d", + key, + artifact_id, + ) + if key not in self._artifact_cache: + logger.error( + "Key %s did not get set!", key, - artifact_id, ) - if key not in self._artifact_cache: - logger.error( - "Key %s did not get set!", - key, - ) - raise RuntimeError("Key not written to cache") + raise RuntimeError("Key not written to cache") # return self._artifact_cache[key] with self._artifact_cache.open(key) as fh: