From fba521ecc6d4c260de8f3e93dadbc3e157a9a9a6 Mon Sep 17 00:00:00 2001 From: Justin Newberry Date: Tue, 2 Jan 2024 15:42:08 -0500 Subject: [PATCH] url_file: fix non-200 files being cached (#30886) * fix + test * fix unclosed * easier to read Co-authored-by: Shane Smiskol * fix that --------- Co-authored-by: Shane Smiskol --- tools/lib/tests/test_caching.py | 79 +++++++++++++++++++++++---------- tools/lib/url_file.py | 2 + 2 files changed, 57 insertions(+), 24 deletions(-) diff --git a/tools/lib/tests/test_caching.py b/tools/lib/tests/test_caching.py index 73ed843869e757..294c5a223388be 100755 --- a/tools/lib/tests/test_caching.py +++ b/tools/lib/tests/test_caching.py @@ -1,15 +1,58 @@ #!/usr/bin/env python3 +from functools import wraps +import http.server import os +import threading +import time import unittest -from pathlib import Path from parameterized import parameterized -from unittest import mock -from openpilot.system.hardware.hw import Paths from openpilot.tools.lib.url_file import URLFile +class CachingTestRequestHandler(http.server.BaseHTTPRequestHandler): + FILE_EXISTS = True + + def do_GET(self): + if self.FILE_EXISTS: + self.send_response(200, b'1234') + else: + self.send_response(404) + self.end_headers() + + def do_HEAD(self): + if self.FILE_EXISTS: + self.send_response(200) + self.send_header("Content-Length", "4") + else: + self.send_response(404) + self.end_headers() + + +class CachingTestServer(threading.Thread): + def run(self): + self.server = http.server.HTTPServer(("127.0.0.1", 0), CachingTestRequestHandler) + self.port = self.server.server_port + self.server.serve_forever() + + def stop(self): + self.server.server_close() + self.server.shutdown() + +def with_caching_server(func): + @wraps(func) + def wrapper(*args, **kwargs): + server = CachingTestServer() + server.start() + time.sleep(0.25) # wait for server to get it's port + try: + func(*args, **kwargs, port=server.port) + finally: + server.stop() + return wrapper + + class TestFileDownload(unittest.TestCase): def compare_loads(self, url, start=0, length=None): @@ -66,32 +109,20 @@ def test_large_file(self): self.compare_loads(large_file_url) @parameterized.expand([(True, ), (False, )]) - def test_recover_from_missing_file(self, cache_enabled): + @with_caching_server + def test_recover_from_missing_file(self, cache_enabled, port): os.environ["FILEREADER_CACHE"] = "1" if cache_enabled else "0" - file_url = "http://localhost:5001/test.png" + file_url = f"http://localhost:{port}/test.png" - file_exists = False + CachingTestRequestHandler.FILE_EXISTS = False + length = URLFile(file_url).get_length() + self.assertEqual(length, -1) - def get_length_online_mock(self): - if file_exists: - return 4 - return -1 + CachingTestRequestHandler.FILE_EXISTS = True + length = URLFile(file_url).get_length() + self.assertEqual(length, 4) - patch_length = mock.patch.object(URLFile, "get_length_online", get_length_online_mock) - patch_length.start() - try: - length = URLFile(file_url).get_length() - self.assertEqual(length, -1) - - file_exists = True - length = URLFile(file_url).get_length() - self.assertEqual(length, 4) - finally: - tempfile_length = Path(Paths.download_cache_root()) / "ba2119904385654cb0105a2da174875f8e7648db175f202ecae6d6428b0e838f_length" - if tempfile_length.exists(): - tempfile_length.unlink() - patch_length.stop() if __name__ == "__main__": diff --git a/tools/lib/url_file.py b/tools/lib/url_file.py index d055f86577579a..97c0a639a79ae0 100644 --- a/tools/lib/url_file.py +++ b/tools/lib/url_file.py @@ -57,6 +57,8 @@ def __exit__(self, exc_type, exc_value, traceback): def get_length_online(self): timeout = Timeout(connect=50.0, read=500.0) response = self._http_client.request('HEAD', self._url, timeout=timeout, preload_content=False) + if not (200 <= response.status <= 299): + return -1 length = response.headers.get('content-length', 0) return int(length)