Skip to content

Commit

Permalink
Stream the pgns
Browse files Browse the repository at this point in the history
  • Loading branch information
ppigazzini committed Apr 25, 2024
1 parent e5e1ae7 commit 2eb47e5
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 23 deletions.
19 changes: 9 additions & 10 deletions server/fishtest/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
HTTPUnauthorized,
exception_response,
)
from pyramid.response import Response
from pyramid.response import FileIter, Response
from pyramid.view import exception_view_config, view_config, view_defaults
from vtjson import ValidationError, validate

Expand Down Expand Up @@ -532,19 +532,18 @@ def download_pgn(self):

@view_config(route_name="api_download_run_pgns")
def download_run_pgns(self):
tar_name = self.request.matchdict["id"]
match = re.match(r"^([a-zA-Z0-9]+)\.pgns\.tar$", tar_name)
pgns_name = self.request.matchdict["id"]
match = re.match(r"^([a-zA-Z0-9]+)\.pgn\.gz$", pgns_name)
if not match:
return Response("Invalid filename format", status=400)
run_id = match.group(1)
pgns_zip = self.request.rundb.get_run_pgns(run_id)
if pgns_zip is None:
pgns_reader = self.request.rundb.get_run_pgns(run_id)
if pgns_reader is None:
return Response("No data found", status=404)
zip_buffer = io.BytesIO(pgns_zip)
response = Response(content_type="application/x-tar")
response.app_iter = zip_buffer
response.content_length = zip_buffer.getbuffer().nbytes
response.headers["Content-Disposition"] = f'attachment; filename="{tar_name}"'
response = Response(content_type="application/gzip")
response.app_iter = FileIter(pgns_reader)
response.headers["Content-Disposition"] = f'attachment; filename="{pgns_name}"'
response.headers["Content-Encoding"] = "gzip"
return response

@view_config(route_name="api_download_nn")
Expand Down
37 changes: 24 additions & 13 deletions server/fishtest/rundb.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import configparser
import copy
import io
import math
import os
import random
import re
import signal
import sys
import tarfile
import textwrap
import threading
import time
Expand Down Expand Up @@ -81,6 +79,24 @@ def get_port():
return -6


class GeneratorAsFileReader:
    """Present an iterator of ``bytes`` chunks through a file-like interface.

    Only ``read`` and ``close`` are provided. Chunks are pulled from the
    iterator lazily, so no more data than a bounded ``read`` needs (plus the
    unread remainder of the last chunk) is held in memory.

    Args:
        generator: an iterator (e.g. a generator object) yielding ``bytes``.
    """

    def __init__(self, generator):
        self.generator = generator
        # Bytes already pulled from the generator but not yet handed out.
        self.buffer = b""

    def read(self, size=-1):
        """Return up to ``size`` bytes, or everything left if ``size`` is
        negative or ``None``.

        Returns ``b""`` once the underlying iterator is exhausted and the
        internal buffer is empty.
        """
        if size is None or size < 0:
            # Read-all: join once instead of slicing with a negative index
            # (``buffer[:-1]`` would silently drop the final byte) and
            # instead of growing the buffer with repeated ``+=``.
            chunks = [self.buffer, *self.generator]
            self.buffer = b""
            return b"".join(chunks)

        # Bounded read: pull chunks only until the request can be satisfied.
        while len(self.buffer) < size:
            try:
                self.buffer += next(self.generator)
            except StopIteration:
                break
        result, self.buffer = self.buffer[:size], self.buffer[size:]
        return result

    def close(self):
        """Do nothing; present only because file-like consumers call it."""
        pass  # No cleanup needed, but method is required


class RunDb:
def __init__(self, db_name="fishtest_new"):
# MongoDB server is assumed to be on the same machine, if not user should
Expand Down Expand Up @@ -270,17 +286,12 @@ def get_pgn(self, run_id):
def get_run_pgns(self, run_id):
pgns = self.pgndb.find({"run_id": {"$regex": f"^{run_id}"}})
if pgns:
with io.BytesIO() as tar_buffer:
with tarfile.open(fileobj=tar_buffer, mode="w") as tarf:
for pgn in pgns:
pgn_zip = pgn["pgn_zip"]
tarinfo = tarfile.TarInfo(f"{pgn['run_id']}.pgn.gz")
tarinfo.size = len(pgn_zip)
# Extract and convert the 4 bytes starting at index 4
tarinfo.mtime = int.from_bytes(pgn_zip[4:8], byteorder="little")
tarf.addfile(tarinfo, io.BytesIO(pgn_zip))
pgns_tar = tar_buffer.getvalue()
return pgns_tar
# Create a generator that yields each pgn.gz file
def pgn_generator():
for pgn in pgns:
yield pgn["pgn_zip"]

return GeneratorAsFileReader(pgn_generator())
return None

def write_nn(self, net):
Expand Down

0 comments on commit 2eb47e5

Please sign in to comment.