diff --git a/server/fishtest/actiondb.py b/server/fishtest/actiondb.py index b20c4b198..b35c6c935 100644 --- a/server/fishtest/actiondb.py +++ b/server/fishtest/actiondb.py @@ -1,6 +1,7 @@ from datetime import datetime, timezone -from fishtest.util import hex_print, union, validate, worker_name +from fishtest.util import hex_print, worker_name +from fishtest.vtjson import _validate, union from pymongo import DESCENDING schema = union( @@ -280,7 +281,7 @@ def block_worker(self, username=None, worker=None, message=None): def insert_action(self, **action): if "run_id" in action: action["run_id"] = str(action["run_id"]) - ret = validate(schema, action, "action", strict=True) + ret = _validate(schema, action, "action") if ret == "": action["time"] = datetime.now(timezone.utc).timestamp() self.actions.insert_one(action) diff --git a/server/fishtest/api.py b/server/fishtest/api.py index 90b8c4255..5e81d2d41 100644 --- a/server/fishtest/api.py +++ b/server/fishtest/api.py @@ -5,7 +5,8 @@ from datetime import datetime, timezone from fishtest.stats.stat_util import SPRT_elo -from fishtest.util import optional_key, union, validate, worker_name +from fishtest.util import worker_name +from fishtest.vtjson import _validate, lax, union from pyramid.httpexceptions import ( HTTPBadRequest, HTTPFound, @@ -33,10 +34,10 @@ def validate_request(request): schema = { "password": str, - optional_key("run_id"): str, - optional_key("task_id"): int, - optional_key("pgn"): str, - optional_key("message"): str, + "run_id?": str, + "task_id?": int, + "pgn?": str, + "message?": str, "worker_info": { "uname": str, "architecture": [str, str], @@ -54,13 +55,13 @@ def validate_request(request): "ARCH": str, "nps": float, }, - optional_key("spsa"): { + "spsa?": { "wins": int, "losses": int, "draws": int, "num_games": int, }, - optional_key("stats"): { + "stats?": { "wins": int, "losses": int, "draws": int, @@ -69,7 +70,7 @@ def validate_request(request): "pentanomial": [int, int, int, int, int], }, } - return validate(schema, request, "request", strict=True) + return _validate(schema, request, "request") # Avoids exposing sensitive data about the workers to the client and skips some heavy data. @@ -137,8 +138,8 @@ def validate_username_password(self, api): self.handle_error("request is not json encoded") # Is the request syntactically correct? - schema = {"password": str, "worker_info": {"username": str}} - self.handle_error(validate(schema, self.request_body, "request")) + schema = lax({"password": str, "worker_info": {"username": str}}) + self.handle_error(_validate(schema, self.request_body, "request")) # is the supplied password correct? token = self.request.userdb.authenticate( diff --git a/server/fishtest/rundb.py b/server/fishtest/rundb.py index ff62e8f4e..2b7e61158 100644 --- a/server/fishtest/rundb.py +++ b/server/fishtest/rundb.py @@ -33,6 +33,7 @@ update_residuals, worker_name, ) +from fishtest.vtjson import _validate, ip_address, number, regex, union, url from fishtest.workerdb import WorkerDb from pymongo import DESCENDING, MongoClient @@ -42,6 +43,196 @@ last_rundb = None +# This schema only matches new runs. The old runs are not +# compatible with it. For documentation purposes it would +# also be useful to have a "universal schema" that matches +# all the runs in the db. +# To make this practical we will eventually put all schemas +# in a separate module "schemas.py". + +net_name = regex("nn-[a-z0-9]{12}.nnue", name="net_name") +tc = regex(r"([1-9]\d*/)?\d+(\.\d+)?(\+\d+(\.\d+)?)?", name="tc") +str_int = regex(r"[1-9]\d*", name="str_int") +sha = regex(r"[a-f0-9]{40}", name="sha") +country_code = regex(r"[A-Z][A-Z]", name="country_code") +run_id = regex(r"[a-f0-9]{24}", name="run_id") + +worker_info_schema = { + "uname": str, + "architecture": [str, str], + "concurrency": int, + "max_memory": int, + "min_threads": int, + "username": str, + "version": int, + "python_version": [int, int, int], + "gcc_version": [int, int, int], + "compiler": union("clang++", "g++"), + "unique_key": str, + "modified": bool, + "ARCH": str, + "nps": number, + "near_github_api_limit": bool, + "remote_addr": ip_address, + "country_code": union(country_code, "?"), +} + +results_schema = { + "wins": int, + "losses": int, + "draws": int, + "crashes": int, + "time_losses": int, + "pentanomial": [int, int, int, int, int], +} + +schema = { + "_id?": ObjectId, + "start_time": datetime, + "last_updated": datetime, + "tc_base": number, + "base_same_as_master": bool, + "rescheduled_from?": run_id, + "approved": bool, + "approver": str, + "finished": bool, + "deleted": bool, + "failed": bool, + "is_green": bool, + "is_yellow": bool, + "workers": int, + "cores": int, + "results": results_schema, + "results_info?": { + "style": str, + "info": [str, ...], + }, + "args": { + "base_tag": str, + "new_tag": str, + "base_net": net_name, + "new_net": net_name, + "num_games": int, + "tc": tc, + "new_tc": tc, + "book": str, + "book_depth": str_int, + "threads": int, + "resolved_base": sha, + "resolved_new": sha, + "msg_base": str, + "msg_new": str, + "base_options": str, + "new_options": str, + "info": str, + "base_signature": str_int, + "new_signature": str_int, + "username": str, + "tests_repo": url, + "auto_purge": bool, + "throughput": number, + "itp": number, + "priority": number, + "adjudication": bool, + "sprt?": { + "alpha": 0.05, + "beta": 0.05, + "elo0": number, + "elo1": number, + "elo_model": "normalized", + "state": union("", "accepted", "rejected"), + "llr": number, + "batch_size": int, + "lower_bound": -math.log(19), + "upper_bound": math.log(19), + "lost_samples?": int, + "illegal_update?": int, + "overshoot?": { + "last_update": int, + "skipped_updates": int, + "ref0": number, + "m0": number, + "sq0": number, + "ref1": number, + "m1": number, + "sq1": number, + }, + }, + "spsa?": { + "A": number, + "alpha": number, + "gamma": number, + "raw_params": str, + "iter": int, + "num_iter": int, + "params": [ + { + "name": str, + "start": number, + "min": number, + "max": number, + "c_end": number, + "r_end": number, + "c": number, + "a_end": number, + "a": number, + "theta": number, + }, + ..., + ], + "param_history?": [ + [{"theta": number, "R": number, "c": number}, ...], + ..., + ], + }, + }, + "tasks": [ + { + "num_games": int, + "active": bool, + "last_updated": datetime, + "start": int, + "residual?": number, + "residual_color?": str, + "bad?": True, + "stats": results_schema, + "worker_info": worker_info_schema, + }, + ..., + ], + "bad_tasks?": [ + { + "num_games": int, + "active": False, + "last_updated": datetime, + "start": int, + "residual": number, + "residual_color": str, + "bad": True, + "task_id": int, + "stats": results_schema, + "worker_info": worker_info_schema, + }, + ..., + ], +} + +# Avoid leaking too many things into the global scope +del ( + country_code, + ip_address, + number, + regex, + results_schema, + run_id, + sha, + str_int, + tc, + union, + url, + worker_info_schema, +) + def get_port(): params = {} @@ -241,6 +432,12 @@ def new_run( if rescheduled_from: new_run["rescheduled_from"] = rescheduled_from + valid = _validate(schema, new_run, "run") + if valid != "": + message = f"The new run object does not _validate: {valid}" + print(message, flush=True) + raise Exception(message) + return self.runs.insert_one(new_run).inserted_id def upload_pgn(self, run_id, pgn_zip): @@ -613,6 +810,7 @@ def compute_results(self, run): """ This is used in purge_run and also to verify the incrementally updated results when a run is finished.""" + results = {"wins": 0, "losses": 0, "draws": 0, "crashes": 0, "time_losses": 0} has_pentanomial = True @@ -1366,6 +1564,12 @@ def stop_run(self, run_id): run["cores"] = 0 run["workers"] = 0 run["finished"] = True + valid = _validate(schema, run, "run") + if valid != "": + print(f"The run object {run_id} does not validate: {valid}", flush=True) + # We are not confident enough to enable this... + # assert False + self.buffer(run, True) # Publish the results of the run to the Fishcooking forum post_in_fishcooking_results(run) diff --git a/server/fishtest/userdb.py b/server/fishtest/userdb.py index 688592eb1..6ce13438f 100644 --- a/server/fishtest/userdb.py +++ b/server/fishtest/userdb.py @@ -4,24 +4,31 @@ from datetime import datetime, timezone from bson.objectid import ObjectId -from fishtest.util import optional_key, validate +from fishtest.vtjson import _validate, email, union, url from pymongo import ASCENDING schema = { - optional_key("_id"): ObjectId, + "_id?": ObjectId, "username": str, "password": str, "registration_time": datetime, "blocked": bool, - "email": str, - "groups": list, - "tests_repo": str, + "email": email, + "groups": [str, ...], + "tests_repo": union("", url), "machine_limit": int, } DEFAULT_MACHINE_LIMIT = 16 +def validate_user(user): + valid = _validate(schema, user, "user") + if valid != "": + print(valid, flush=True) + assert False + + class UserDb: def __init__(self, db): self.db = db @@ -89,7 +96,7 @@ def get_user_groups(self, username): def add_user_group(self, username, group): user = self.find(username) user["groups"].append(group) - assert validate(schema, user, "user", strict=True) == "" + validate_user(user) self.users.replace_one({"_id": user["_id"]}, user) self.clear_cache() @@ -108,7 +115,7 @@ def create_user(self, username, password, email): "tests_repo": "", "machine_limit": DEFAULT_MACHINE_LIMIT, } - assert validate(schema, user, "user", strict=True) == "" + validate_user(user) self.users.insert_one(user) self.last_pending_time = 0 @@ -117,7 +124,7 @@ def create_user(self, username, password, email): return False def save_user(self, user): - assert validate(schema, user, "user", strict=True) == "" + validate_user(user) self.users.replace_one({"_id": user["_id"]}, user) self.last_pending_time = 0 self.clear_cache() diff --git a/server/fishtest/util.py b/server/fishtest/util.py index 1a830cb7b..7b4064701 100644 --- a/server/fishtest/util.py +++ b/server/fishtest/util.py @@ -465,86 +465,6 @@ def password_strength(password, *args): return False, "Non-empty password required" -class optional_key: - def __init__(self, key): - self.key = key - - -class union: - def __init__(self, *schemas): - self.schemas = schemas - - def __validate__(self, object, name, strict=False): - messages = [] - for schema in self.schemas: - message = validate(schema, object, name, strict=strict) - if message == "": - return "" - else: - messages.append(message) - return " and ".join(messages) - - -def _keys(dict): - ret = set() - for k in dict: - if isinstance(k, optional_key): - ret.add(k.key) - else: - ret.add(k) - return ret - - -def validate(schema, object, name, strict=False): - if hasattr(schema, "__validate__"): # duck typing - return schema.__validate__(object, name, strict=strict) - elif isinstance(schema, type): - if not isinstance(object, schema): - return f"{name} is not of type {schema.__name__}" - else: - return "" - elif isinstance(schema, (list, tuple)): - if not isinstance(schema, type(object)): - return f"{name} is not of type {type(schema).__name}" - l = len(object) - if strict and l != len(schema): - return f"{name} does not have length {len(schema)}" - for i in range(len(schema)): - name_ = f"{name}[{i}]" - if i >= l: - return f"{name_} does not exist" - else: - ret = validate(schema[i], object[i], name_, strict=strict) - if ret != "": - return ret - return "" - elif isinstance(schema, dict): - if not isinstance(schema, type(object)): - return f"{name} is not of type {type(schema).__name}" - if strict: - _k = _keys(schema) - for x in object: - if x not in _k: - return f"{name}['{x}'] is not in the schema" - for k in schema: - k_ = k - if isinstance(k, optional_key): - k_ = k.key - if k_ not in object: - continue - name_ = f"{name}['{k_}']" - if k_ not in object: - return f"{name_} is missing" - else: - ret = validate(schema[k], object[k_], name_, strict=strict) - if ret != "": - return ret - return "" - elif object != schema: - return f"{name} is not equal to {repr(schema)}" - return "" - - # Workaround for a bug in pyramid.request.cookies. # Chrome may send different cookies with the same name. # The one that applies is the first one (the one with the diff --git a/server/fishtest/vtjson.py b/server/fishtest/vtjson.py new file mode 100644 index 000000000..ac388d984 --- /dev/null +++ b/server/fishtest/vtjson.py @@ -0,0 +1,646 @@ +import datetime +import ipaddress +import math +import re +import urllib.parse + +import dns.resolver +import email_validator +import idna + + +class ValidationError(Exception): + pass + + +class SchemaError(Exception): + pass + + +try: + from types import GenericAlias as _GenericAlias +except ImportError: + # For compatibility with older Pythons + class _GenericAlias(type): + pass + + +__version__ = "1.3.5" + + +def _c(s): + ss = str(s) + if len(ss) > 0: + c = ss[-1] + else: + c = "" + if len(ss) < 100: + ret = ss + else: + ret = f"{ss[:100]}...[TRUNCATED]..." + if not isinstance(s, str) and c in r"])}": + ret += c + if isinstance(s, str): + return repr(ret) + else: + return ret + + +def _wrong_type_message(object, name, type_name, explanation=None): + message = f"{name} (value:{_c(object)}) is not of type '{type_name}'" + if explanation is not None: + message += f": {explanation}" + return message + + +def _keys2(dict): + ret = set() + for k in dict: + if isinstance(k, optional_key): + ret.add((k.key, k, True)) + elif isinstance(k, str) and len(k) > 0 and k[-1] == "?": + ret.add((k[:-1], k, True)) + else: + ret.add((k, k, False)) + return ret + + +def _keys(dict): + return {k[0] for k in _keys2(dict)} + + +class _validate_meta(type): + def __instancecheck__(cls, object): + valid = _validate(cls.__schema__, object, "object", strict=cls.__strict__) + if cls.__debug__ and valid != "": + print(f"DEBUG: {valid}") + return valid == "" + + +def make_type(schema, name=None, strict=True, debug=False): + if name is None: + if hasattr(schema, "__name__"): + name = schema.__name__ + else: + name = "schema" + return _validate_meta( + name, (), {"__schema__": schema, "__strict__": strict, "__debug__": debug} + ) + + +class optional_key: + def __init__(self, key): + self.key = key + + +class union: + def __init__(self, *schemas): + self.schemas = [_compile(s) for s in schemas] + + def __validate__(self, object, name, strict): + messages = [] + for schema in self.schemas: + message = schema.__validate__(object, name=name, strict=strict) + if message == "": + return "" + else: + messages.append(message) + return " and ".join(messages) + + +class intersect: + def __init__(self, *schemas): + self.schemas = [_compile(s) for s in schemas] + + def __validate__(self, object, name, strict): + for schema in self.schemas: + message = schema.__validate__(object, name=name, strict=strict) + if message != "": + return message + return "" + + +class complement: + def __init__(self, schema): + self.schema = _compile(schema) + + def __validate__(self, object, name, strict): + message = self.schema.__validate__(object, name=name, strict=strict) + if message != "": + return "" + else: + return f"{name} does not match the complemented schema" + + +class lax: + def __init__(self, schema): + self.schema = _compile(schema) + + def __validate__(self, object, name, strict): + return self.schema.__validate__(object, name=name, strict=False) + + +class strict: + def __init__(self, schema): + self.schema = _compile(schema) + + def __validate__(self, object, name, strict): + return self.schema.__validate__(object, name=name, strict=True) + + +class quote: + def __init__(self, schema): + self.schema = _object(schema) + + def __validate__(self, object, name, strict): + return self.schema.__validate__(object, name, strict) + + +class set_name: + def __init__(self, schema, name): + self.schema = _compile(schema) + self.__name__ = name + + def __validate__(self, object, name, strict): + message = self.schema.__validate__(object, name=name, strict=strict) + if message != "": + return _wrong_type_message(object, name, self.__name__) + return "" + + +class regex: + def __init__(self, regex, name=None, fullmatch=True, flags=0): + self.regex = regex + self.fullmatch = fullmatch + if name is not None: + self.__name__ = name + else: + _flags = "" if flags == 0 else f", flags={flags}" + _fullmatch = "" if fullmatch else ", fullmatch=False" + self.__name__ = f"regex({repr(regex)}{_fullmatch}{_flags})" + + schema_error = False + try: + self.pattern = re.compile(regex, flags) + except Exception as e: + schema_error = True + message = str(e) + if schema_error: + _name = f" (name: {repr(name)})" if name is not None else "" + raise SchemaError( + f"{regex}{_name} is an invalid regular expression: {message}" + ) + + def __validate__(self, object, name, strict): + try: + if self.fullmatch and self.pattern.fullmatch(object): + return "" + elif not self.fullmatch and self.pattern.match(object): + return "" + except Exception: + pass + return _wrong_type_message(object, name, self.__name__) + + +class interval: + def __init__(self, lb, ub): + self.lb = lb + self.ub = ub + self.lb_s = "..." if lb == ... else repr(lb) + self.ub_s = "..." if ub == ... else repr(ub) + + if lb is ... and ub is ...: + self.__validate__ = self.__validate_none__ + elif lb is ...: + self.__validate__ = self.__validate_ub__ + elif ub is ...: + self.__validate__ = self.__validate_lb__ + else: + schema_error = False + try: + lb <= ub + except Exception: + schema_error = True + if schema_error: + raise SchemaError( + f"The upper and lower bound in the interval" + f" [{self.lb_s},{self.ub_s}] are incomparable" + ) + + def message(self, name, object): + return ( + f"{name} (value:{_c(object)}) is not in the interval " + f"[{self.lb_s},{self.ub_s}]" + ) + + def __validate__(self, object, name, strict): + try: + if self.lb <= object <= self.ub: + return "" + else: + return self.message(name, object) + except Exception as e: + return f"{self.message(name, object)}: {str(e)}" + + def __validate_ub__(self, object, name, strict): + try: + if object <= self.ub: + return "" + else: + return self.message(name, object) + except Exception as e: + return f"{self.message(name, object)}: {str(e)}" + + def __validate_lb__(self, object, name, strict): + try: + if object >= self.lb: + return "" + else: + return self.message(name, object) + except Exception as e: + return f"{self.message(name, object)}: {str(e)}" + + def __validate_none__(self, object, name, strict): + return "" + + +def _compile(schema): + if hasattr(schema, "__validate__"): + return schema + elif isinstance(schema, type) or isinstance(schema, _GenericAlias): + return _type(schema) + elif callable(schema): + return _callable(schema) + elif isinstance(schema, tuple) or isinstance(schema, list): + return _sequence(schema) + elif isinstance(schema, dict): + return _dict(schema) + elif isinstance(schema, set): + return union(*schema) + else: + return _object(schema) + + +def _validate(schema, object, name="object", strict=True): + schema = _compile(schema) + return schema.__validate__(object, name=name, strict=strict) + + +def validate(schema, object, name="object", strict=True): + message = _validate(schema, object, name=name, strict=strict) + if message != "": + raise ValidationError(message) + + +# Some predefined schemas + + +class number: + @staticmethod + def __validate__(object, name, strict): + return _number.__validate__(object, name, strict) + + def __init__(self): + self.__validate__ = self.__validate2__ + + def __validate2__(self, object, name, strict): + if isinstance(object, int) or isinstance(object, float): + return "" + else: + return _wrong_type_message(object, name, "number") + + +_number = number() + + +class email: + _resolver = email_validator.caching_resolver(timeout=10) + + @staticmethod + def __validate__(object, name, strict): + return _email.__validate__(object, name, strict) + + def __init__(self, *args, **kw): + self.args = args + self.kw = kw + if "dns_resolver" not in kw: + self.kw["dns_resolver"] = self._resolver + if "check_deliverability" not in kw: + self.kw["check_deliverability"] = False + self.__validate__ = self.__validate2__ + + def __validate2__(self, object, name, strict): + try: + email_validator.validate_email(object, *self.args, **self.kw) + return "" + except email_validator.EmailNotValidError as e: + return _wrong_type_message(object, name, "email", str(e)) + + +_email = email() + + +class ip_address: + @staticmethod + def __validate__(object, name, strict): + return _ip_address.__validate__(object, name, strict) + + def __init__(self): + self.__validate__ = self.__validate2__ + + def __validate2__(self, object, name, strict): + try: + ipaddress.ip_address(object) + return "" + except ValueError: + return _wrong_type_message(object, name, "ip_address") + + +_ip_address = ip_address() + + +class url: + @staticmethod + def __validate__(object, name, strict): + return _url.__validate__(object, name, strict) + + def __init__(self): + self.__validate__ = self.__validate2__ + + def __validate2__(self, object, name, strict): + result = urllib.parse.urlparse(object) + if all([result.scheme, result.netloc]): + return "" + return _wrong_type_message(object, name, "url") + + +_url = url() + + +class date_time: + @staticmethod + def __validate__(object, name, strict): + return _date_time.__validate__(object, name, strict) + + def __init__(self, format=None): + self.format = format + self.__validate__ = self.__validate2__ + if format is not None: + self.__name__ = f"date_time({repr(format)})" + else: + self.__name__ = "date_time" + + def __validate2__(self, object, name, strict): + if self.format is not None: + try: + datetime.datetime.strptime(object, self.format) + except Exception as e: + return _wrong_type_message(object, name, self.__name__, str(e)) + else: + try: + datetime.datetime.fromisoformat(object) + except Exception as e: + return _wrong_type_message(object, name, self.__name__, str(e)) + return "" + + +_date_time = date_time() + + +class date: + @staticmethod + def __validate__(object, name, strict): + return _date.__validate__(object, name, strict) + + def __init__(self): + self.__validate__ = self.__validate2__ + self.__name__ = "date" + + def __validate2__(self, object, name, strict): + try: + datetime.date.fromisoformat(object) + except Exception as e: + return _wrong_type_message(object, name, self.__name__, str(e)) + return "" + + +_date = date() + + +class time: + @staticmethod + def __validate__(object, name, strict): + return _time.__validate__(object, name, strict) + + def __init__(self): + self.__validate__ = self.__validate2__ + self.__name__ = "time" + + def __validate2__(self, object, name, strict): + try: + datetime.time.fromisoformat(object) + except Exception as e: + return _wrong_type_message(object, name, self.__name__, str(e)) + return "" + + +_time = time() + + +class domain_name: + def __validate__(object, name, strict): + return _domain_name.__validate__(object, name, strict) + + def __init__(self, ascii_only=True, resolve=False): + self.re_ascii = re.compile(r"[\x00-\x7F]*") + self.ascii_only = ascii_only + self.resolve = resolve + self.__validate__ = self.__validate2__ + arg_string = "" + if not ascii_only: + arg_string += ", ascii_only=False" + if resolve: + arg_string += ", resolve=True" + if arg_string != "": + arg_string = arg_string[2:] + self.__name__ = ( + "domain_name" if not arg_string else f"domain_name({arg_string})" + ) + self._resolver = dns.resolver.Resolver() + self._resolver.cache = dns.resolver.LRUCache() + + def __validate2__(self, object, name, strict): + if self.ascii_only: + if not self.re_ascii.fullmatch(object): + return _wrong_type_message( + object, name, self.__name__, "Non-ascii characters" + ) + try: + idna.encode(object, uts46=False) + except idna.core.IDNAError as e: + return _wrong_type_message(object, name, self.__name__, str(e)) + + if self.resolve: + try: + self._resolver.resolve(object) + except Exception as e: + return _wrong_type_message(object, name, self.__name__, str(e)) + return "" + + +_domain_name = domain_name() + + +class _dict: + def __init__(self, schema): + self.schema = {} + for k, v in schema.items(): + self.schema[k] = _compile(v) + self.keys = _keys(self.schema) + self.keys2 = _keys2(self.schema) + + def __validate__(self, object, name, strict): + if type(object) is not dict: + return _wrong_type_message(object, name, type(self.schema).__name__) + if strict: + for x in object: + if x not in self.keys: + return f"{name}['{x}'] is not in the schema" + for k_, k, o in self.keys2: + name_ = f"{name}['{k_}']" + if k not in object: + if o: + continue + else: + return f"{name_} is missing" + else: + ret = self.schema[k_].__validate__(object[k], name=name_, strict=strict) + if ret != "": + return ret + return "" + + def __str__(self): + return str(self.schema) + + +class _type: + def __init__(self, schema): + self.schema = schema + if isinstance(schema, _GenericAlias): + raise SchemaError("Parametrized generics are not supported!") + + def __validate__(self, object, name, strict): + try: + if not isinstance(object, self.schema): + return _wrong_type_message(object, name, self.schema.__name__) + else: + return "" + except Exception as e: + return f"{self.schema} is not a valid type: {str(e)}" + + def __str__(self): + return self.type.__name__ + + +class _sequence: + def __init__(self, schema): + self.type_schema = type(schema) + self.schema = [_compile(o) if o is not ... else ... for o in schema] + if len(schema) > 0 and schema[-1] is ...: + if len(schema) >= 2: + self.fill = self.schema[-2] + self.schema = self.schema[:-2] + else: + self.fill = _type(object) + self.schema = [] + self.__validate__ = self.__validate_ellipsis__ + + def __validate__(self, object, name, strict): + if self.type_schema is not type(object): + return _wrong_type_message(object, name, type(self.schema).__name__) + ls = len(self.schema) + lo = len(object) + if strict: + if lo > ls: + return f"{name}[{ls}] is not in the schema" + if ls > lo: + return f"{name}[{lo}] is missing" + for i in range(ls): + name_ = f"{name}[{i}]" + ret = self.schema[i].__validate__(object[i], name_, strict) + if ret != "": + return ret + return "" + + def __validate_ellipsis__(self, object, name, strict): + if self.type_schema is not type(object): + return _wrong_type_message(object, name, type(self.schema).__name__) + ls = len(self.schema) + lo = len(object) + if ls > lo: + return f"{name}[{lo}] is missing" + for i in range(ls): + name_ = f"{name}[{i}]" + ret = self.schema[i].__validate__(object[i], name_, strict) + if ret != "": + return ret + for i in range(ls + 1, lo): + name_ = f"{name}[{i}]" + ret = self.fill.__validate__(object[i], name_, strict) + if ret != "": + return ret + return "" + + def __str__(self): + return str(self.schema) + + +class _object: + def __init__(self, schema): + self.schema = schema + if isinstance(schema, float): + self.__validate__ = self.__validate_float__ + + def message(self, name, object): + return f"{name} (value:{_c(object)}) is not equal to {repr(self.schema)}" + + def __validate__(self, object, name, strict): + if object != self.schema: + return self.message(name, object) + return "" + + def message_float(self, name, object): + return f"{name} (value:{_c(object)}) is not close to {repr(self.schema)}" + + def __validate_float__(self, object, name, strict): + try: + if math.isclose(self.schema, object): + return "" + else: + return self.message_float(name, object) + except Exception: + return self.message_float(name, object) + + def __str__(self): + return str(self.schema) + + +class _callable: + def __init__(self, schema): + self.schema = schema + try: + self.__name__ = self.schema.__name__ + except Exception: + self.__name__ = self.schema + + def __validate__(self, object, name, strict): + try: + if self.schema(object): + return "" + else: + return _wrong_type_message(object, name, self.__name__) + except Exception as e: + return _wrong_type_message(object, name, self.__name__, str(e)) + + def __str__(self): + return str(self.schema) diff --git a/server/fishtest/workerdb.py b/server/fishtest/workerdb.py index 5e726b0bd..1e2bc8842 100644 --- a/server/fishtest/workerdb.py +++ b/server/fishtest/workerdb.py @@ -1,6 +1,6 @@ from datetime import datetime, timezone -from fishtest.util import validate +from fishtest.vtjson import _validate schema = { "worker_name": str, @@ -40,7 +40,7 @@ def update_worker(self, worker_name, blocked=None, message=None): "message": message, "last_updated": datetime.now(timezone.utc), } - assert validate(schema, r, "worker", strict=True) == "" + assert _validate(schema, r, "worker") == "" self.workers.replace_one({"worker_name": worker_name}, r, upsert=True) def get_blocked_workers(self): diff --git a/server/tests/test_api.py b/server/tests/test_api.py index 72d5b84a2..bc56a8dee 100644 --- a/server/tests/test_api.py +++ b/server/tests/test_api.py @@ -28,12 +28,24 @@ def new_run(self, add_tasks=0): "10+0.01", "10+0.01", "book", - 10, + "10", 1, "", "", + info="The ultimate patch", + resolved_base="347d613b0e2c47f90cbf1c5a5affe97303f1ac3d", + resolved_new="347d613b0e2c47f90cbf1c5a5affe97303f1ac3d", + msg_base="Bad stuff", + msg_new="Super stuff", + base_signature="123456", + new_signature="654321", + base_net="nn-0000000000a0.nnue", + new_net="nn-0000000000a0.nnue", + rescheduled_from="653db116cc309ae839563103", + base_same_as_master=False, + tests_repo="https://google.com", + auto_purge=False, username="travis", - tests_repo="travis", start_time=datetime.now(timezone.utc), ) run = self.rundb.get_run(run_id) @@ -41,11 +53,23 @@ def new_run(self, add_tasks=0): if add_tasks > 0: run["workers"] = run["cores"] = 0 for i in range(add_tasks): + worker_info = copy.deepcopy(self.worker_info) + worker_info["remote_addr"] = self.remote_addr + worker_info["country_code"] = self.country_code task = { "num_games": self.chunk_size, - "stats": {"wins": 0, "draws": 0, "losses": 0, "crashes": 0}, + "stats": { + "wins": 0, + "draws": 0, + "losses": 0, + "crashes": 0, + "time_losses": 0, + "pentanomial": [0, 0, 0, 0, 0], + }, "active": True, - "worker_info": copy.deepcopy(self.worker_info), + "last_updated": datetime.now(timezone.utc), + "start": 1234, + "worker_info": worker_info, } run["workers"] += 1 run["cores"] += self.worker_info["concurrency"] @@ -77,6 +101,7 @@ def setUpClass(self): self.password = "secret" self.unique_key = "unique key" self.remote_addr = "127.0.0.1" + self.country_code = "US" self.concurrency = 7 self.worker_info = { diff --git a/server/tests/test_rundb.py b/server/tests/test_rundb.py index db50f4307..2da9f8d25 100644 --- a/server/tests/test_rundb.py +++ b/server/tests/test_rundb.py @@ -51,21 +51,37 @@ def tearDown(self): def test_10_create_run(self): global run_id # STC + num_tasks = 4 + num_games = num_tasks * self.chunk_size + run_id_stc = self.rundb.new_run( "master", "master", - 100000, + num_games, "10+0.01", "10+0.01", "book", - 10, + "10", 1, "", "", + info="The ultimate patch", + resolved_base="347d613b0e2c47f90cbf1c5a5affe97303f1ac3d", + resolved_new="347d613b0e2c47f90cbf1c5a5affe97303f1ac3d", + msg_base="Bad stuff", + msg_new="Super stuff", + base_signature="123456", + new_signature="654321", + base_net="nn-0000000000a0.nnue", + new_net="nn-0000000000a0.nnue", + rescheduled_from="653db116cc309ae839563103", + base_same_as_master=False, + tests_repo="https://google.com", + auto_purge=False, username="travis", - tests_repo="travis", start_time=datetime.now(timezone.utc), ) + run = self.rundb.get_run(run_id_stc) run["finished"] = True task = { @@ -82,20 +98,30 @@ def test_10_create_run(self): run_id = self.rundb.new_run( "master", "master", - 100000, - "150+0.01", - "150+0.01", + num_games, + "10+0.01", + "10+0.01", "book", - 10, + "10", 1, "", "", + info="The ultimate patch", + resolved_base="347d613b0e2c47f90cbf1c5a5affe97303f1ac3d", + resolved_new="347d613b0e2c47f90cbf1c5a5affe97303f1ac3d", + msg_base="Bad stuff", + msg_new="Super stuff", + base_signature="123456", + new_signature="654321", + base_net="nn-0000000000a0.nnue", + new_net="nn-0000000000a0.nnue", + rescheduled_from="653db116cc309ae839563103", + base_same_as_master=False, + tests_repo="https://google.com", + auto_purge=False, username="travis", - tests_repo="travis", start_time=datetime.now(timezone.utc), ) - print(" ") - print(run_id) run = self.rundb.get_run(run_id) task = { "num_games": self.chunk_size, diff --git a/server/tests/test_validate.py b/server/tests/test_validate.py deleted file mode 100644 index 458e702b8..000000000 --- a/server/tests/test_validate.py +++ /dev/null @@ -1,97 +0,0 @@ -import unittest - -from fishtest.util import _keys, optional_key, union, validate - - -class TestValidation(unittest.TestCase): - def test_keys(self): - schema = {optional_key("a"): 1, "b": 2, optional_key("c"): 3} - keys = _keys(schema) - self.assertEqual(keys, {"a", "b", "c"}) - - def test_strict(self): - schema = {optional_key("a"): 1, "b": 2} - name = "my_object" - object = {"b": 2, "c": 3} - valid = validate(schema, object, name, strict=True) - self.assertFalse(valid == "") - - object = {"a": 1, "c": 3} - valid = validate(schema, object, name, strict=True) - self.assertFalse(valid == "") - - object = {"a": 1, "b": 2} - valid = validate(schema, object, name, strict=True) - self.assertTrue(valid == "") - - object = {"b": 2} - valid = validate(schema, object, name, strict=True) - self.assertTrue(valid == "") - - def test_missing_keys(self): - schema = {optional_key("a"): 1, "b": 2} - name = "my_object" - object = {"b": 2, "c": 3} - valid = validate(schema, object, name, strict=False) - self.assertTrue(valid == "") - - object = {"a": 1, "c": 3} - valid = validate(schema, object, name, strict=False) - self.assertFalse(valid == "") - - object = {"a": 1, "b": 2} - valid = validate(schema, object, name, strict=False) - self.assertTrue(valid == "") - - object = {"b": 2} - valid = validate(schema, object, name, strict=False) - self.assertTrue(valid == "") - - def test_union(self): - schema = {optional_key("a"): 1, "b": union(2, 3)} - name = "my_object" - object = {"b": 2, "c": 3} - valid = validate(schema, object, name, strict=False) - self.assertTrue(valid == "") - - object = {"b": 4, "c": 3} - valid = validate(schema, object, name, strict=False) - self.assertFalse(valid == "") - - def test_validate(self): - class lower_case_string: - @staticmethod - def __validate__(object, name, strict=False): - if not isinstance(object, str): - return f"{name} is not a string" - for c in object: - if not ("a" <= c <= "z"): - return f"{c}, contained in the string {name}, is not a lower case letter" - return "" - - schema = lower_case_string - object = 1 - name = "my_object" - valid = validate(schema, object, name, strict=True) - self.assertFalse(valid == "") - - object = "aA" - valid = validate(schema, object, name, strict=True) - self.assertFalse(valid == "") - - object = "ab" - valid = validate(schema, object, name, strict=True) - self.assertTrue(valid == "") - - schema = {"a": lower_case_string} - object = {"a": "ab"} - valid = validate(schema, object, name, strict=True) - self.assertTrue(valid == "") - - object = {"a": "AA"} - valid = validate(schema, object, name, strict=True) - self.assertFalse(valid == "") - - -if __name__ == "__main__": - unittest.main()