diff --git a/server/fishtest/api.py b/server/fishtest/api.py index 9787c45904..25c6c9ec0a 100644 --- a/server/fishtest/api.py +++ b/server/fishtest/api.py @@ -13,7 +13,7 @@ HTTPUnauthorized, exception_response, ) -from pyramid.response import Response +from pyramid.response import FileIter, Response from pyramid.view import exception_view_config, view_config, view_defaults from vtjson import ValidationError, validate @@ -532,19 +532,18 @@ def download_pgn(self): @view_config(route_name="api_download_run_pgns") def download_run_pgns(self): - tar_name = self.request.matchdict["id"] - match = re.match(r"^([a-zA-Z0-9]+)\.pgns\.tar$", tar_name) + pgns_name = self.request.matchdict["id"] + match = re.match(r"^([a-zA-Z0-9]+)\.pgn\.gz$", pgns_name) if not match: return Response("Invalid filename format", status=400) run_id = match.group(1) - pgns_zip = self.request.rundb.get_run_pgns(run_id) - if pgns_zip is None: + pgns_reader = self.request.rundb.get_run_pgns(run_id) + if pgns_reader is None: return Response("No data found", status=404) - zip_buffer = io.BytesIO(pgns_zip) - response = Response(content_type="application/x-tar") - response.app_iter = zip_buffer - response.content_length = zip_buffer.getbuffer().nbytes - response.headers["Content-Disposition"] = f'attachment; filename="{tar_name}"' + response = Response(content_type="application/gzip") + response.app_iter = FileIter(pgns_reader) + response.headers["Content-Disposition"] = f'attachment; filename="{pgns_name}"' + response.headers["Content-Encoding"] = "gzip" return response @view_config(route_name="api_download_nn") diff --git a/server/fishtest/rundb.py b/server/fishtest/rundb.py index b9ed94b506..33a89b78b0 100644 --- a/server/fishtest/rundb.py +++ b/server/fishtest/rundb.py @@ -1,13 +1,11 @@ import configparser import copy -import io import math import os import random import re import signal import sys -import tarfile import textwrap import threading import time @@ -22,6 +20,7 @@ from fishtest.stats.stat_util import SPRT_elo from fishtest.userdb import UserDb from fishtest.util import ( + GeneratorAsFileReader, crash_or_time, estimate_game_duration, format_bounds, @@ -269,18 +268,10 @@ def get_pgn(self, run_id): def get_run_pgns(self, run_id): pgns = self.pgndb.find({"run_id": {"$regex": f"^{run_id}"}}) - if pgns: - with io.BytesIO() as tar_buffer: - with tarfile.open(fileobj=tar_buffer, mode="w") as tarf: - for pgn in pgns: - pgn_zip = pgn["pgn_zip"] - tarinfo = tarfile.TarInfo(f"{pgn['run_id']}.pgn.gz") - tarinfo.size = len(pgn_zip) - # Extract and convert the 4 bytes starting at index 4 - tarinfo.mtime = int.from_bytes(pgn_zip[4:8], byteorder="little") - tarf.addfile(tarinfo, io.BytesIO(pgn_zip)) - pgns_tar = tar_buffer.getvalue() - return pgns_tar + if pgns is not None: + # Create a generator that yields each pgn.gz file + pgn_generator = (pgn["pgn_zip"] for pgn in pgns) + return GeneratorAsFileReader(pgn_generator) return None def write_nn(self, net): diff --git a/server/fishtest/util.py b/server/fishtest/util.py index 730316b96e..b7ea591af4 100644 --- a/server/fishtest/util.py +++ b/server/fishtest/util.py @@ -15,6 +15,24 @@ FISH_URL = "https://tests.stockfishchess.org/tests/view/" +class GeneratorAsFileReader: + def __init__(self, generator): + self.generator = generator + self.buffer = b"" + + def read(self, size=-1): + while size < 0 or len(self.buffer) < size: + try: + self.buffer += next(self.generator) + except StopIteration: + break + result, self.buffer = self.buffer[:size], self.buffer[size:] + return result + + def close(self): + pass # No cleanup needed, but method is required + + def hex_print(s): return hashlib.md5(str(s).encode("utf-8")).digest().hex()