From f8b9c2863f69b881d6a513b841a60846fcab2a60 Mon Sep 17 00:00:00 2001 From: Michel Van den Bergh Date: Tue, 5 Dec 2023 20:25:15 +0000 Subject: [PATCH] Tighten up the schema in api.py. The api's in api.py are a direct connection between the workers and the db. So it makes sense to be very strict in what we accept. Note: a serious bug was discovered in the handling of optional keys in vtjson. So vtjson should be upgraded to the latest version. Since upgrading is necessary anyway, this PR depends on a previously internal feature of vtjson, namely the possibility to pre-compile a schema. Validating an update_task api call for an SPRT run takes 0.026ms on an AWS entry level t2.micro using Python 3.11 (without pre-compiling it takes 0.068ms). --- server/fishtest/api.py | 117 ++++++++++++++++++++++++------------- server/tests/test_api.py | 37 ++++-------- server/tests/test_rundb.py | 2 +- 3 files changed, 89 insertions(+), 67 deletions(-) diff --git a/server/fishtest/api.py b/server/fishtest/api.py index 365f112a1..7bbc1fd8e 100644 --- a/server/fishtest/api.py +++ b/server/fishtest/api.py @@ -14,7 +14,7 @@ ) from pyramid.response import Response from pyramid.view import exception_view_config, view_config, view_defaults -from vtjson import _validate, lax, union +from vtjson import _validate, compile, intersect, interval, lax, regex """ Important note @@ -30,47 +30,86 @@ WORKER_VERSION = 222 +""" +begin api_schema +""" -def validate_request(request): - schema = { - "password": str, - "run_id?": str, - "task_id?": int, - "pgn?": str, - "message?": str, - "worker_info": { - "uname": str, - "architecture": [str, str], - "concurrency": int, - "max_memory": int, - "min_threads": int, - "username": str, - "version": int, - "python_version": [int, int, int], - "gcc_version": [int, int, int], - "compiler": union("g++", "clang++"), - "unique_key": str, - "modified": bool, - "near_github_api_limit": bool, - "ARCH": str, - "nps": float, +run_id = regex(r"[a-f0-9]{24}", name="run_id") +uuid = regex(r"[0-9a-zA-z]{2,8}(-[0-9a-f]{4}){3}-[0-9a-f]{12}", name="uuid") + +uint = intersect(int, interval(0, ...)) +suint = intersect(int, interval(1, ...)) +ufloat = intersect(float, interval(0.0, ...)) + + +def valid_results(R): + l, d, w = R["losses"], R["draws"], R["wins"] + R = R["pentanomial"] + return ( + l + d + w == 2 * sum(R) + and w - l == 2 * R[4] + R[3] - R[1] - 2 * R[0] + and R[3] + 2 * R[2] + R[1] >= d >= R[3] + R[1] + ) + + +def valid_spsa_results(R): + return R["wins"] + R["losses"] + R["draws"] == R["num_games"] + + +api_schema = { + "password": str, + "run_id?": run_id, + "task_id?": uint, + "pgn?": str, + "message?": str, + "worker_info": { + "uname": str, + "architecture": [str, str], + "concurrency": suint, + "max_memory": suint, + "min_threads": suint, + "username": str, + "version": uint, + "python_version": [uint, uint, uint], + "gcc_version": [uint, uint, uint], + "compiler": {"g++", "clang++"}, + "unique_key": uuid, + "modified": bool, + "near_github_api_limit": bool, + "ARCH": str, + "nps": ufloat, + }, + "spsa?": intersect( + { + "wins": uint, + "losses": uint, + "draws": uint, + "num_games": uint, }, - "spsa?": { - "wins": int, - "losses": int, - "draws": int, - "num_games": int, + valid_spsa_results, + ), + "stats?": intersect( + { + "wins": uint, + "losses": uint, + "draws": uint, + "crashes": uint, + "time_losses": uint, + "pentanomial": [uint, uint, uint, uint, uint], }, - "stats?": { - "wins": int, - "losses": int, - "draws": int, - "crashes": int, - "time_losses": int, - "pentanomial": [int, int, int, int, int], - }, - } - return _validate(schema, request, "request") + valid_results, + ), +} + +api_schema = compile(api_schema) + +""" +end api_schema +""" + + +def validate_request(request): + return _validate(api_schema, request, "request") # Avoids exposing sensitive data about the workers to the client and skips some heavy data. diff --git a/server/tests/test_api.py b/server/tests/test_api.py index bc56a8dee..27e9d2c8a 100644 --- a/server/tests/test_api.py +++ b/server/tests/test_api.py @@ -7,7 +7,7 @@ from datetime import datetime, timezone from fishtest.api import WORKER_VERSION, ApiView -from pyramid.httpexceptions import HTTPUnauthorized +from pyramid.httpexceptions import HTTPBadRequest, HTTPUnauthorized from pyramid.testing import DummyRequest from util import get_rundb @@ -99,7 +99,7 @@ def setUpClass(self): # Set up an API user (a worker) self.username = "JoeUserWorker" self.password = "secret" - self.unique_key = "unique key" + self.unique_key = "amaya-5a28-4b7d-b27b-d78d97ecf11a" self.remote_addr = "127.0.0.1" self.country_code = "US" self.concurrency = 7 @@ -123,7 +123,7 @@ def setUpClass(self): 0, ], "compiler": "g++", - "unique_key": "unique key", + "unique_key": "amaya-5a28-4b7d-b27b-d78d97ecf11a", "modified": True, "near_github_api_limit": False, "ARCH": "?", @@ -271,8 +271,8 @@ def test_update_task(self): "time_losses": 0, "pentanomial": [0, 0, d // 2, 0, w // 2], } - response = ApiView(cleanup(request)).update_task() - self.assertFalse(response["task_alive"]) + with self.assertRaises(HTTPBadRequest): + response = ApiView(cleanup(request)).update_task() request.json_body["stats"] = { "wins": w + 2, @@ -282,27 +282,10 @@ def test_update_task(self): "time_losses": 0, "pentanomial": [0, 0, d // 2, 0, w // 2 + 1], } - response = ApiView(cleanup(request)).update_task() - self.assertFalse(response["task_alive"]) - response = ApiView(cleanup(request)).update_task() - self.assertTrue("info" in response) - print(response["info"]) - - # revive the task - run["tasks"][0]["active"] = True - self.rundb.buffer(run, True) - - request.json_body["stats"] = { - "wins": w + 2, - "draws": d, - "losses": 0, - "crashes": 0, - "time_losses": 0, - "pentanomial": [0, 0, d // 2, 0, w // 2 + 1], - } response = ApiView(cleanup(request)).update_task() self.assertTrue(response["task_alive"]) + # Go back in time request.json_body["stats"] = { "wins": w, @@ -479,7 +462,7 @@ def setUpClass(self): # Set up an API user (a worker) self.username = "JoeUserWorker" self.password = "secret" - self.unique_key = "unique key" + self.unique_key = "amaya-5a28-4b7d-b27b-d78d97ecf11a" self.remote_addr = "127.0.0.1" self.concurrency = 7 @@ -502,7 +485,7 @@ def setUpClass(self): 0, ], "compiler": "g++", - "unique_key": "unique key", + "unique_key": "amaya-5a28-4b7d-b27b-d78d97ecf11a", "near_github_api_limit": False, "modified": True, "ARCH": "?", @@ -610,8 +593,8 @@ def test_auto_purge_runs(self): self.assertEqual(task_start2, task_size1) # Finish task 2 of 2 - n_wins = task_size2 // 5 - n_losses = task_size2 // 5 + n_wins = 2 * ((task_size2 // 5) // 2) + n_losses = 2 * ((task_size2 // 5) // 2) n_draws = task_size2 - n_wins - n_losses request = self.correct_password_request( diff --git a/server/tests/test_rundb.py b/server/tests/test_rundb.py index 2da9f8d25..eb5c44c21 100644 --- a/server/tests/test_rundb.py +++ b/server/tests/test_rundb.py @@ -36,7 +36,7 @@ def setUp(self): 3, 0, ], - "unique_key": "unique key", + "unique_key": "amaya-5a28-4b7d-b27b-d78d97ecf11a", "near_github_api_limit": False, "modified": True, "ARCH": "?",