From 1f4b6902477a9086df7841bb0e649add169cc3f1 Mon Sep 17 00:00:00 2001
From: Joost VandeVondele
Date: Thu, 27 Aug 2020 17:23:12 +0200
Subject: [PATCH] Use black to reformat code

This reformats the code using `black --exclude=worker/requests .`

Fixes https://github.com/glinscott/fishtest/issues/634

As with any tool, not every choice is perfect, but the formatting is now
consistent throughout the code base. I propose applying this patch (or
running the above black command) as part of the next worker change.
---
 fishtest/fishtest/__init__.py        |  171 +--
 fishtest/fishtest/actiondb.py        |   99 +-
 fishtest/fishtest/api.py             |  444 +++---
 fishtest/fishtest/helpers.py         |   28 +-
 fishtest/fishtest/models.py          |   11 +-
 fishtest/fishtest/rundb.py           | 1784 ++++++++++++------
 fishtest/fishtest/stats/LLRcalc.py   |  150 +-
 fishtest/fishtest/stats/brownian.py  |  140 +-
 fishtest/fishtest/stats/sprt.py      |  184 +--
 fishtest/fishtest/stats/stat_util.py |  551 +++++---
 fishtest/fishtest/userdb.py          |  195 +--
 fishtest/fishtest/util.py            |  554 ++++----
 fishtest/fishtest/views.py           | 1950 ++++++++++++++------------
 fishtest/run_all_tests.py            |    8 +-
 fishtest/setup.py                    |   61 +-
 fishtest/tests/test_api.py           |  910 ++++++------
 fishtest/tests/test_run.py           |    8 +-
 fishtest/tests/test_rundb.py         |  199 ++-
 fishtest/tests/test_users.py         |  207 +--
 fishtest/tests/util.py               |   18 +-
 fishtest/utils/clone_fish.py         |   94 +-
 fishtest/utils/compact_actions.py    |   45 +-
 fishtest/utils/create_indexes.py     |  198 +--
 fishtest/utils/create_pgndb.py       |    8 +-
 fishtest/utils/current.py            |   24 +-
 fishtest/utils/delta_update_users.py |  358 ++---
 fishtest/utils/index_pending.py      |   20 +-
 fishtest/utils/purge_pgn.py          |   88 +-
 fishtest/utils/scavenge.py           |   33 +-
 fishtest/utils/test_queries.py       |   64 +-
 worker/games.py                      | 1606 ++++++++++++---------
 worker/setup.py                      |    6 +-
 worker/test_worker.py                |   85 +-
 worker/updater.py                    |  138 +-
 worker/worker.py                     |  537 +++----
 35 files changed, 5903 insertions(+), 5073 deletions(-)

diff --git a/fishtest/fishtest/__init__.py b/fishtest/fishtest/__init__.py
index f32ff8b12..0a019ded6 100644
--- a/fishtest/fishtest/__init__.py
+++ b/fishtest/fishtest/__init__.py
@@ -10,88 +10,93 @@
 from fishtest.rundb import RunDb
 from fishtest import helpers
 
+
 def main(global_config, **settings):
-  """ This function returns a Pyramid WSGI application.
+    """ This function returns a Pyramid WSGI application.
""" - session_factory = UnencryptedCookieSessionFactoryConfig('fishtest') - config = Configurator(settings=settings, - session_factory=session_factory, - root_factory='fishtest.models.RootFactory') - config.include('pyramid_mako') - config.set_default_csrf_options(require_csrf=False) - - rundb = RunDb() - - def add_rundb(event): - event.request.rundb = rundb - event.request.userdb = rundb.userdb - event.request.actiondb = rundb.actiondb - - def add_renderer_globals(event): - event['h'] = helpers - - config.add_subscriber(add_rundb, NewRequest) - config.add_subscriber(add_renderer_globals, BeforeRender) - - # Authentication - def group_finder(username, request): - return request.userdb.get_user_groups(username) - - with open(os.path.expanduser('~/fishtest.secret'), 'r') as f: - secret = f.read() - config.set_authentication_policy( - AuthTktAuthenticationPolicy( - secret, callback=group_finder, hashalg='sha512', http_only=True)) - config.set_authorization_policy(ACLAuthorizationPolicy()) - - config.add_static_view('html', 'static/html', cache_max_age=3600) - config.add_static_view('css', 'static/css', cache_max_age=3600) - config.add_static_view('js', 'static/js', cache_max_age=3600) - config.add_static_view('img', 'static/img', cache_max_age=3600) - - config.add_route('home', '/') - config.add_route('login', '/login') - config.add_route('nn_upload', '/upload') - config.add_route('logout', '/logout') - config.add_route('signup', '/signup') - config.add_route('user', '/user/{username}') - config.add_route('profile', '/user') - config.add_route('pending', '/pending') - config.add_route('users', '/users') - config.add_route('users_monthly', '/users/monthly') - config.add_route('actions', '/actions') - config.add_route('nns', '/nns') - - config.add_route('tests', '/tests') - config.add_route('tests_machines', '/tests/machines') - config.add_route('tests_finished', '/tests/finished') - config.add_route('tests_run', '/tests/run') - config.add_route('tests_view', '/tests/view/{id}') - config.add_route('tests_view_spsa_history', '/tests/view/{id}/spsa_history') - config.add_route('tests_user', '/tests/user/{username}') - config.add_route('tests_stats', '/tests/stats/{id}') - - # Tests - actions - config.add_route('tests_modify', '/tests/modify') - config.add_route('tests_delete', '/tests/delete') - config.add_route('tests_stop', '/tests/stop') - config.add_route('tests_approve', '/tests/approve') - config.add_route('tests_purge', '/tests/purge') - - # API - config.add_route('api_request_task', '/api/request_task') - config.add_route('api_update_task', '/api/update_task') - config.add_route('api_failed_task', '/api/failed_task') - config.add_route('api_stop_run', '/api/stop_run') - config.add_route('api_request_version', '/api/request_version') - config.add_route('api_request_spsa', '/api/request_spsa') - config.add_route('api_active_runs', '/api/active_runs') - config.add_route('api_get_run', '/api/get_run/{id}') - config.add_route('api_upload_pgn', '/api/upload_pgn') - config.add_route('api_download_pgn', '/api/pgn/{id}') - config.add_route('api_download_pgn_100', '/api/pgn_100/{skip}') - config.add_route('api_download_nn', '/api/nn/{id}') - config.add_route('api_get_elo', '/api/get_elo/{id}') - - config.scan() - return config.make_wsgi_app() + session_factory = UnencryptedCookieSessionFactoryConfig("fishtest") + config = Configurator( + settings=settings, + session_factory=session_factory, + root_factory="fishtest.models.RootFactory", + ) + config.include("pyramid_mako") + 
config.set_default_csrf_options(require_csrf=False) + + rundb = RunDb() + + def add_rundb(event): + event.request.rundb = rundb + event.request.userdb = rundb.userdb + event.request.actiondb = rundb.actiondb + + def add_renderer_globals(event): + event["h"] = helpers + + config.add_subscriber(add_rundb, NewRequest) + config.add_subscriber(add_renderer_globals, BeforeRender) + + # Authentication + def group_finder(username, request): + return request.userdb.get_user_groups(username) + + with open(os.path.expanduser("~/fishtest.secret"), "r") as f: + secret = f.read() + config.set_authentication_policy( + AuthTktAuthenticationPolicy( + secret, callback=group_finder, hashalg="sha512", http_only=True + ) + ) + config.set_authorization_policy(ACLAuthorizationPolicy()) + + config.add_static_view("html", "static/html", cache_max_age=3600) + config.add_static_view("css", "static/css", cache_max_age=3600) + config.add_static_view("js", "static/js", cache_max_age=3600) + config.add_static_view("img", "static/img", cache_max_age=3600) + + config.add_route("home", "/") + config.add_route("login", "/login") + config.add_route("nn_upload", "/upload") + config.add_route("logout", "/logout") + config.add_route("signup", "/signup") + config.add_route("user", "/user/{username}") + config.add_route("profile", "/user") + config.add_route("pending", "/pending") + config.add_route("users", "/users") + config.add_route("users_monthly", "/users/monthly") + config.add_route("actions", "/actions") + config.add_route("nns", "/nns") + + config.add_route("tests", "/tests") + config.add_route("tests_machines", "/tests/machines") + config.add_route("tests_finished", "/tests/finished") + config.add_route("tests_run", "/tests/run") + config.add_route("tests_view", "/tests/view/{id}") + config.add_route("tests_view_spsa_history", "/tests/view/{id}/spsa_history") + config.add_route("tests_user", "/tests/user/{username}") + config.add_route("tests_stats", "/tests/stats/{id}") + + # Tests - actions + config.add_route("tests_modify", "/tests/modify") + config.add_route("tests_delete", "/tests/delete") + config.add_route("tests_stop", "/tests/stop") + config.add_route("tests_approve", "/tests/approve") + config.add_route("tests_purge", "/tests/purge") + + # API + config.add_route("api_request_task", "/api/request_task") + config.add_route("api_update_task", "/api/update_task") + config.add_route("api_failed_task", "/api/failed_task") + config.add_route("api_stop_run", "/api/stop_run") + config.add_route("api_request_version", "/api/request_version") + config.add_route("api_request_spsa", "/api/request_spsa") + config.add_route("api_active_runs", "/api/active_runs") + config.add_route("api_get_run", "/api/get_run/{id}") + config.add_route("api_upload_pgn", "/api/upload_pgn") + config.add_route("api_download_pgn", "/api/pgn/{id}") + config.add_route("api_download_pgn_100", "/api/pgn_100/{skip}") + config.add_route("api_download_nn", "/api/nn/{id}") + config.add_route("api_get_elo", "/api/get_elo/{id}") + + config.scan() + return config.make_wsgi_app() diff --git a/fishtest/fishtest/actiondb.py b/fishtest/fishtest/actiondb.py index 5c813faa3..5486f0d4a 100644 --- a/fishtest/fishtest/actiondb.py +++ b/fishtest/fishtest/actiondb.py @@ -3,52 +3,53 @@ class ActionDb: - def __init__(self, db): - self.db = db - self.actions = self.db['actions'] - - def get_actions(self, max_num, action=None, username=None): - q = {} - if action: - q['action'] = action - else: - q['action'] = {"$ne": 'update_stats'} - if username: - q['username'] = 
username - return self.actions.find(q, sort=[('_id', DESCENDING)], limit=max_num) - - def update_stats(self): - self._new_action('fishtest.system', 'update_stats', '') - - def new_run(self, username, run): - self._new_action(username, 'new_run', run) - - def upload_nn(self, username, network): - self._new_action(username, 'upload_nn', network) - - def modify_run(self, username, before, after): - self._new_action(username, 'modify_run', - {'before': before, 'after': after}) - - def delete_run(self, username, run): - self._new_action(username, 'delete_run', run) - - def stop_run(self, username, run): - self._new_action(username, 'stop_run', run) - - def approve_run(self, username, run): - self._new_action(username, 'approve_run', run) - - def purge_run(self, username, run): - self._new_action(username, 'purge_run', run) - - def block_user(self, username, data): - self._new_action(username, 'block_user', data) - - def _new_action(self, username, action, data): - self.actions.insert_one({ - 'username': username, - 'action': action, - 'data': data, - 'time': datetime.utcnow(), - }) + def __init__(self, db): + self.db = db + self.actions = self.db["actions"] + + def get_actions(self, max_num, action=None, username=None): + q = {} + if action: + q["action"] = action + else: + q["action"] = {"$ne": "update_stats"} + if username: + q["username"] = username + return self.actions.find(q, sort=[("_id", DESCENDING)], limit=max_num) + + def update_stats(self): + self._new_action("fishtest.system", "update_stats", "") + + def new_run(self, username, run): + self._new_action(username, "new_run", run) + + def upload_nn(self, username, network): + self._new_action(username, "upload_nn", network) + + def modify_run(self, username, before, after): + self._new_action(username, "modify_run", {"before": before, "after": after}) + + def delete_run(self, username, run): + self._new_action(username, "delete_run", run) + + def stop_run(self, username, run): + self._new_action(username, "stop_run", run) + + def approve_run(self, username, run): + self._new_action(username, "approve_run", run) + + def purge_run(self, username, run): + self._new_action(username, "purge_run", run) + + def block_user(self, username, data): + self._new_action(username, "block_user", data) + + def _new_action(self, username, action, data): + self.actions.insert_one( + { + "username": username, + "action": action, + "data": data, + "time": datetime.utcnow(), + } + ) diff --git a/fishtest/fishtest/api.py b/fishtest/fishtest/api.py index 1d4ad987c..adf3d3a67 100644 --- a/fishtest/fishtest/api.py +++ b/fishtest/fishtest/api.py @@ -17,239 +17,225 @@ def strip_run(run): - run = copy.deepcopy(run) - if 'tasks' in run: - del run['tasks'] - if 'bad_tasks' in run: - del run['bad_tasks'] - if 'spsa' in run['args'] and 'param_history' in run['args']['spsa']: - del run['args']['spsa']['param_history'] - run['_id'] = str(run['_id']) - run['start_time'] = str(run['start_time']) - run['last_updated'] = str(run['last_updated']) - return run + run = copy.deepcopy(run) + if "tasks" in run: + del run["tasks"] + if "bad_tasks" in run: + del run["bad_tasks"] + if "spsa" in run["args"] and "param_history" in run["args"]["spsa"]: + del run["args"]["spsa"]["param_history"] + run["_id"] = str(run["_id"]) + run["start_time"] = str(run["start_time"]) + run["last_updated"] = str(run["last_updated"]) + return run @exception_view_config(HTTPUnauthorized) def authentication_failed(error, request): - response = Response(json_body=error.detail) - response.status_int = 401 
- return response + response = Response(json_body=error.detail) + response.status_int = 401 + return response -@view_defaults(renderer='json') +@view_defaults(renderer="json") class ApiView(object): - ''' All API endpoints that require authentication are used by workers ''' - - def __init__(self, request): - self.request = request - - def require_authentication(self): - token = self.request.userdb.authenticate(self.get_username(), - self.request.json_body['password']) - if 'error' in token: - raise HTTPUnauthorized(token) - - def get_username(self): - if 'username' in self.request.json_body: - return self.request.json_body['username'] - return self.request.json_body['worker_info']['username'] - - def get_flag(self): - ip = self.request.remote_addr - if ip in flag_cache: - return flag_cache.get(ip, None) # Handle race condition on "del" - # concurrent invocations get None, race condition is not an issue - flag_cache[ip] = None - result = self.request.userdb.flag_cache.find_one({'ip': ip}) - if result: - flag_cache[ip] = result['country_code'] - return result['country_code'] - try: - # Get country flag from worker IP address - FLAG_HOST = 'https://freegeoip.app/json/' - r = requests.get(FLAG_HOST + self.request.remote_addr, timeout=1.0) - if r.status_code == 200: - country_code = r.json()['country_code'] - self.request.userdb.flag_cache.insert_one({ - 'ip': ip, - 'country_code': country_code, - 'geoip_checked_at': datetime.utcnow() - }) - flag_cache[ip] = country_code - return country_code - raise Error("flag server failed") - except: - del flag_cache[ip] - print('Failed GeoIP check for {}'.format(ip)) - return None - - def run_id(self): - return str(self.request.json_body['run_id']) - - def task_id(self): - return int(self.request.json_body['task_id']) - - - @view_config(route_name='api_active_runs') - def active_runs(self): - active = {} - for run in self.request.rundb.get_unfinished_runs(): - active[str(run['_id'])] = strip_run(run) - return active - - - @view_config(route_name='api_get_run') - def get_run(self): - run = self.request.rundb.get_run(self.request.matchdict['id']) - return strip_run(run) - - - @view_config(route_name='api_get_elo') - def get_elo(self): - run = self.request.rundb.get_run(self.request.matchdict['id']).copy() - results = run['results'] - if 'sprt' not in run['args']: - return {} - sprt = run['args'].get('sprt').copy() - elo_model = sprt.get('elo_model', 'BayesElo') - alpha = sprt['alpha'] - beta = sprt['beta'] - elo0 = sprt['elo0'] - elo1 = sprt['elo1'] - sprt['elo_model'] = elo_model - a = SPRT_elo(results, - alpha=alpha, beta=beta, - elo0=elo0, elo1=elo1, - elo_model=elo_model) - run = strip_run(run) - run['elo'] = a - run['args']['sprt'] = sprt - return run - - - @view_config(route_name='api_request_task') - def request_task(self): - self.require_authentication() - - worker_info = self.request.json_body['worker_info'] - worker_info['remote_addr'] = self.request.remote_addr - flag = self.get_flag() - if flag: - worker_info['country_code'] = flag - - result = self.request.rundb.request_task(worker_info) - if 'task_waiting' in result: - return result - - # Strip the run of unneccesary information - run = result['run'] - min_run = { - '_id': str(run['_id']), - 'args': run['args'], - 'tasks': [], - } - if int(str(worker_info['version']).split(':')[0]) > 64: - task = run['tasks'][result['task_id']] - min_task = {'num_games': task['num_games']} - if 'stats' in task: - min_task['stats'] = task['stats'] - min_run['my_task'] = min_task - else: - for task in 
run['tasks']: - min_task = {'num_games': task['num_games']} - if 'stats' in task: - min_task['stats'] = task['stats'] - min_run['tasks'].append(min_task) - - result['run'] = min_run - return result - - - @view_config(route_name='api_update_task') - def update_task(self): - self.require_authentication() - return self.request.rundb.update_task( - run_id=self.run_id(), - task_id=self.task_id(), - stats=self.request.json_body['stats'], - nps=self.request.json_body.get('nps', 0), - ARCH=self.request.json_body.get('ARCH', '?'), - spsa=self.request.json_body.get('spsa', {}), - username=self.get_username() - ) - - - @view_config(route_name='api_failed_task') - def failed_task(self): - self.require_authentication() - return self.request.rundb.failed_task(self.run_id(), self.task_id()) - - - @view_config(route_name='api_upload_pgn') - def upload_pgn(self): - self.require_authentication() - return self.request.rundb.upload_pgn( - run_id='{}-{}'.format(self.run_id(), self.task_id()), - pgn_zip=base64.b64decode(self.request.json_body['pgn']) - ) - - - @view_config(route_name='api_download_pgn', renderer='string') - def download_pgn(self): - pgn = self.request.rundb.get_pgn(self.request.matchdict['id']) - if pgn is None: - raise exception_response(404) - if '.pgn' in self.request.matchdict['id']: - self.request.response.content_type = 'application/x-chess-pgn' - return pgn - - - @view_config(route_name='api_download_pgn_100') - def download_pgn_100(self): - skip = int(self.request.matchdict['skip']) - urls = self.request.rundb.get_pgn_100(skip) - if urls is None: - raise exception_response(404) - return urls - - - @view_config(route_name='api_download_nn') - def download_nn(self): - nn = self.request.rundb.get_nn(self.request.matchdict['id']) - if nn is None: - raise exception_response(404) - #self.request.response.content_type = 'application/x-chess-nnue' - #self.request.response.body = zlib.decompress(nn['nn']) - #return self.request.response - return HTTPFound("https://data.stockfishchess.org/nn/" - + self.request.matchdict['id']) - - - @view_config(route_name='api_stop_run') - def stop_run(self): - self.require_authentication() - username = self.get_username() - user = self.request.userdb.user_cache.find_one({'username': username}) - if not user or user['cpu_hours'] < 1000: - return {} - with self.request.rundb.active_run_lock(self.run_id()): - run = self.request.rundb.get_run(self.run_id()) - run['finished'] = True - run['failed'] = True - run['stop_reason'] = self.request.json_body.get('message', 'API request') - self.request.actiondb.stop_run(username, run) - self.request.rundb.stop_run(self.run_id()) - return {} - - - @view_config(route_name='api_request_version') - def request_version(self): - self.require_authentication() - return {'version': WORKER_VERSION} - - - @view_config(route_name='api_request_spsa') - def request_spsa(self): - self.require_authentication() - return self.request.rundb.request_spsa(self.run_id(), self.task_id()) + """ All API endpoints that require authentication are used by workers """ + + def __init__(self, request): + self.request = request + + def require_authentication(self): + token = self.request.userdb.authenticate( + self.get_username(), self.request.json_body["password"] + ) + if "error" in token: + raise HTTPUnauthorized(token) + + def get_username(self): + if "username" in self.request.json_body: + return self.request.json_body["username"] + return self.request.json_body["worker_info"]["username"] + + def get_flag(self): + ip = self.request.remote_addr + if 
ip in flag_cache: + return flag_cache.get(ip, None) # Handle race condition on "del" + # concurrent invocations get None, race condition is not an issue + flag_cache[ip] = None + result = self.request.userdb.flag_cache.find_one({"ip": ip}) + if result: + flag_cache[ip] = result["country_code"] + return result["country_code"] + try: + # Get country flag from worker IP address + FLAG_HOST = "https://freegeoip.app/json/" + r = requests.get(FLAG_HOST + self.request.remote_addr, timeout=1.0) + if r.status_code == 200: + country_code = r.json()["country_code"] + self.request.userdb.flag_cache.insert_one( + { + "ip": ip, + "country_code": country_code, + "geoip_checked_at": datetime.utcnow(), + } + ) + flag_cache[ip] = country_code + return country_code + raise Error("flag server failed") + except: + del flag_cache[ip] + print("Failed GeoIP check for {}".format(ip)) + return None + + def run_id(self): + return str(self.request.json_body["run_id"]) + + def task_id(self): + return int(self.request.json_body["task_id"]) + + @view_config(route_name="api_active_runs") + def active_runs(self): + active = {} + for run in self.request.rundb.get_unfinished_runs(): + active[str(run["_id"])] = strip_run(run) + return active + + @view_config(route_name="api_get_run") + def get_run(self): + run = self.request.rundb.get_run(self.request.matchdict["id"]) + return strip_run(run) + + @view_config(route_name="api_get_elo") + def get_elo(self): + run = self.request.rundb.get_run(self.request.matchdict["id"]).copy() + results = run["results"] + if "sprt" not in run["args"]: + return {} + sprt = run["args"].get("sprt").copy() + elo_model = sprt.get("elo_model", "BayesElo") + alpha = sprt["alpha"] + beta = sprt["beta"] + elo0 = sprt["elo0"] + elo1 = sprt["elo1"] + sprt["elo_model"] = elo_model + a = SPRT_elo( + results, alpha=alpha, beta=beta, elo0=elo0, elo1=elo1, elo_model=elo_model + ) + run = strip_run(run) + run["elo"] = a + run["args"]["sprt"] = sprt + return run + + @view_config(route_name="api_request_task") + def request_task(self): + self.require_authentication() + + worker_info = self.request.json_body["worker_info"] + worker_info["remote_addr"] = self.request.remote_addr + flag = self.get_flag() + if flag: + worker_info["country_code"] = flag + + result = self.request.rundb.request_task(worker_info) + if "task_waiting" in result: + return result + + # Strip the run of unneccesary information + run = result["run"] + min_run = {"_id": str(run["_id"]), "args": run["args"], "tasks": []} + if int(str(worker_info["version"]).split(":")[0]) > 64: + task = run["tasks"][result["task_id"]] + min_task = {"num_games": task["num_games"]} + if "stats" in task: + min_task["stats"] = task["stats"] + min_run["my_task"] = min_task + else: + for task in run["tasks"]: + min_task = {"num_games": task["num_games"]} + if "stats" in task: + min_task["stats"] = task["stats"] + min_run["tasks"].append(min_task) + + result["run"] = min_run + return result + + @view_config(route_name="api_update_task") + def update_task(self): + self.require_authentication() + return self.request.rundb.update_task( + run_id=self.run_id(), + task_id=self.task_id(), + stats=self.request.json_body["stats"], + nps=self.request.json_body.get("nps", 0), + ARCH=self.request.json_body.get("ARCH", "?"), + spsa=self.request.json_body.get("spsa", {}), + username=self.get_username(), + ) + + @view_config(route_name="api_failed_task") + def failed_task(self): + self.require_authentication() + return self.request.rundb.failed_task(self.run_id(), self.task_id()) + 
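+    # Illustrative request body for the endpoint below, inferred from the
+    # handler code (not an official schema):
+    #   {"password": ..., "worker_info": {"username": ...},
+    #    "run_id": ..., "task_id": ...,
+    #    "pgn": <base64-encoded, zlib-compressed PGN data>}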
+ @view_config(route_name="api_upload_pgn") + def upload_pgn(self): + self.require_authentication() + return self.request.rundb.upload_pgn( + run_id="{}-{}".format(self.run_id(), self.task_id()), + pgn_zip=base64.b64decode(self.request.json_body["pgn"]), + ) + + @view_config(route_name="api_download_pgn", renderer="string") + def download_pgn(self): + pgn = self.request.rundb.get_pgn(self.request.matchdict["id"]) + if pgn is None: + raise exception_response(404) + if ".pgn" in self.request.matchdict["id"]: + self.request.response.content_type = "application/x-chess-pgn" + return pgn + + @view_config(route_name="api_download_pgn_100") + def download_pgn_100(self): + skip = int(self.request.matchdict["skip"]) + urls = self.request.rundb.get_pgn_100(skip) + if urls is None: + raise exception_response(404) + return urls + + @view_config(route_name="api_download_nn") + def download_nn(self): + nn = self.request.rundb.get_nn(self.request.matchdict["id"]) + if nn is None: + raise exception_response(404) + # self.request.response.content_type = 'application/x-chess-nnue' + # self.request.response.body = zlib.decompress(nn['nn']) + # return self.request.response + return HTTPFound( + "https://data.stockfishchess.org/nn/" + self.request.matchdict["id"] + ) + + @view_config(route_name="api_stop_run") + def stop_run(self): + self.require_authentication() + username = self.get_username() + user = self.request.userdb.user_cache.find_one({"username": username}) + if not user or user["cpu_hours"] < 1000: + return {} + with self.request.rundb.active_run_lock(self.run_id()): + run = self.request.rundb.get_run(self.run_id()) + run["finished"] = True + run["failed"] = True + run["stop_reason"] = self.request.json_body.get("message", "API request") + self.request.actiondb.stop_run(username, run) + self.request.rundb.stop_run(self.run_id()) + return {} + + @view_config(route_name="api_request_version") + def request_version(self): + self.require_authentication() + return {"version": WORKER_VERSION} + + @view_config(route_name="api_request_spsa") + def request_spsa(self): + self.require_authentication() + return self.request.rundb.request_spsa(self.run_id(), self.task_id()) diff --git a/fishtest/fishtest/helpers.py b/fishtest/fishtest/helpers.py index 7e14c0026..7c5e87678 100644 --- a/fishtest/fishtest/helpers.py +++ b/fishtest/fishtest/helpers.py @@ -1,17 +1,21 @@ def tests_repo(run): - return run['args'].get('tests_repo', 'https://github.com/official-stockfish/Stockfish') + return run["args"].get( + "tests_repo", "https://github.com/official-stockfish/Stockfish" + ) + def master_diff_url(run): - return "https://github.com/official-stockfish/Stockfish/compare/master...{}".format( - run['args']['resolved_base'][:10] - ) + return "https://github.com/official-stockfish/Stockfish/compare/master...{}".format( + run["args"]["resolved_base"][:10] + ) + def diff_url(run): - if run['args'].get('spsa'): - return master_diff_url(run) - else: - return "{}/compare/{}...{}".format( - tests_repo(run), - run['args']['resolved_base'][:10], - run['args']['resolved_new'][:10] - ) + if run["args"].get("spsa"): + return master_diff_url(run) + else: + return "{}/compare/{}...{}".format( + tests_repo(run), + run["args"]["resolved_base"][:10], + run["args"]["resolved_new"][:10], + ) diff --git a/fishtest/fishtest/models.py b/fishtest/fishtest/models.py index 07ae5b19c..79779f8f8 100644 --- a/fishtest/fishtest/models.py +++ b/fishtest/fishtest/models.py @@ -1,9 +1,8 @@ from pyramid.security import Allow, Everyone + class 
RootFactory(object): - __acl__ = [(Allow, Everyone, 'view'), - (Allow, 'group:approvers', 'approve_run') - ] - def __init__(self, request): - pass - \ No newline at end of file + __acl__ = [(Allow, Everyone, "view"), (Allow, "group:approvers", "approve_run")] + + def __init__(self, request): + pass diff --git a/fishtest/fishtest/rundb.py b/fishtest/fishtest/rundb.py index eb3083bbb..84e30f82f 100644 --- a/fishtest/fishtest/rundb.py +++ b/fishtest/fishtest/rundb.py @@ -19,879 +19,957 @@ import fishtest.stats.stat_util from fishtest.util import ( - calculate_residuals, - estimate_game_duration, - format_results, - post_in_fishcooking_results, - remaining_hours + calculate_residuals, + estimate_game_duration, + format_results, + post_in_fishcooking_results, + remaining_hours, ) last_rundb = None class RunDb: - def __init__(self, db_name='fishtest_new'): - # MongoDB server is assumed to be on the same machine, if not user should - # use ssh with port forwarding to access the remote host. - self.conn = MongoClient(os.getenv('FISHTEST_HOST') or 'localhost') - self.db = self.conn[db_name] - self.userdb = UserDb(self.db) - self.actiondb = ActionDb(self.db) - self.pgndb = self.db['pgns'] - self.nndb = self.db['nns'] - self.runs = self.db['runs'] - self.deltas = self.db['deltas'] - - self.chunk_size = 200 - - global last_rundb - last_rundb = self - - def generate_tasks(self, num_games): - tasks = [] - remaining = num_games - while remaining > 0: - task_size = min(self.chunk_size, remaining) - tasks.append({ - 'num_games': task_size, - 'pending': True, - 'active': False, - }) - remaining -= task_size - return tasks - - def new_run(self, base_tag, new_tag, num_games, tc, book, book_depth, - threads, base_options, new_options, - info='', - resolved_base='', - resolved_new='', - msg_base='', - msg_new='', - base_signature='', - new_signature='', - base_net=None, - new_net=None, - rescheduled_from=None, - base_same_as_master=None, - start_time=None, - sprt=None, - spsa=None, - username=None, - tests_repo=None, - auto_purge=False, - throughput=100, - priority=0): - if start_time is None: - start_time = datetime.utcnow() - - run_args = { - 'base_tag': base_tag, - 'new_tag': new_tag, - 'base_net': base_net, - 'new_net': new_net, - 'num_games': num_games, - 'tc': tc, - 'book': book, - 'book_depth': book_depth, - 'threads': threads, - 'resolved_base': resolved_base, - 'resolved_new': resolved_new, - 'msg_base': msg_base, - 'msg_new': msg_new, - 'base_options': base_options, - 'new_options': new_options, - 'info': info, - 'base_signature': base_signature, - 'new_signature': new_signature, - 'username': username, - 'tests_repo': tests_repo, - 'auto_purge': auto_purge, - 'throughput': throughput, - 'itp': 100, # internal throughput - 'priority': priority, - } - - if sprt is not None: - run_args['sprt'] = sprt - - if spsa is not None: - run_args['spsa'] = spsa - - tc_base = re.search('^(\d+(\.\d+)?)', tc) - if tc_base: - tc_base = float(tc_base.group(1)) - new_run = { - 'args': run_args, - 'start_time': start_time, - 'last_updated': start_time, - 'tc_base': tc_base, - 'base_same_as_master': base_same_as_master, - # Will be filled in by tasks, indexed by task-id - 'tasks': self.generate_tasks(num_games), - # Aggregated results - 'results': {'wins': 0, 'losses': 0, 'draws': 0}, - 'results_stale': False, - 'finished': False, - 'approved': False, - 'approver': '', - } - - if rescheduled_from: - new_run['rescheduled_from'] = rescheduled_from - - return self.runs.insert_one(new_run).inserted_id - - def 
get_machines(self): - machines = [] - active_runs = self.runs.find({ - 'finished': False, - 'tasks': { - '$elemMatch': {'active': True} - } - }, sort=[('last_updated', DESCENDING)]) - for run in active_runs: - for task in run['tasks']: - if task['active']: - machine = copy.copy(task['worker_info']) - machine['last_updated'] = task.get('last_updated', None) - machine['run'] = run - machine['nps'] = task.get('nps', 0) - machines.append(machine) - return machines - - def get_pgn(self, pgn_id): - pgn_id = pgn_id.split('.')[0] # strip .pgn - pgn = self.pgndb.find_one({'run_id': pgn_id}) - if pgn: - return zlib.decompress(pgn['pgn_zip']).decode() - return None - - def get_pgn_100(self, skip): - return [p['run_id'] for p in - self.pgndb.find(skip=skip, limit=100, sort=[('_id', DESCENDING)])] - - def upload_nn(self, userid, name, nn): - self.nndb.insert_one({'user': userid, 'name': name, 'downloads': 0}) - # 'nn': Binary(zlib.compress(nn))}) - return {} - - def update_nn(self, net): - net.pop('downloads', None) - self.nndb.update_one({'name': net['name']}, {'$set': net}) - - def get_nn(self, name): - # nn = self.nndb.find_one({'name': name}) - nn = self.nndb.find_one({'name': name}, {'nn': 0}) - if nn: - self.nndb.update_one({'name': name}, {'$inc': {'downloads': 1}}) - return nn - return None - - def get_nns(self, limit): - return [dict(n, time=n['_id'].generation_time) for n in - self.nndb.find({}, {'nn': 0}, - limit=limit, sort=[('_id', DESCENDING)])] - - - # Cache runs - run_cache = {} - run_cache_lock = threading.Lock() - run_cache_write_lock = threading.Lock() - - timer = None - - # handle termination - def exit_run(signum, frame): - global last_rundb - if last_rundb: - last_rundb.flush_all() - sys.exit(0) - - signal.signal(signal.SIGINT, exit_run) - signal.signal(signal.SIGTERM, exit_run) - - def get_run(self, r_id): - with self.run_cache_lock: - r_id = str(r_id) - if r_id in self.run_cache: - self.run_cache[r_id]['rtime'] = time.time() - return self.run_cache[r_id]['run'] - try: - run = self.runs.find_one({'_id': ObjectId(r_id)}) - if run: - self.run_cache[r_id] = {'rtime': time.time(), 'ftime': time.time(), - 'run': run, 'dirty': False} - return run - except: + def __init__(self, db_name="fishtest_new"): + # MongoDB server is assumed to be on the same machine, if not user should + # use ssh with port forwarding to access the remote host. 
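+        # (e.g. "ssh -L 27017:localhost:27017 user@dbhost" forwards the
+        # default MongoDB port from an illustrative remote host to this
+        # machine, so the default "localhost" connection below still works)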
+ self.conn = MongoClient(os.getenv("FISHTEST_HOST") or "localhost") + self.db = self.conn[db_name] + self.userdb = UserDb(self.db) + self.actiondb = ActionDb(self.db) + self.pgndb = self.db["pgns"] + self.nndb = self.db["nns"] + self.runs = self.db["runs"] + self.deltas = self.db["deltas"] + + self.chunk_size = 200 + + global last_rundb + last_rundb = self + + def generate_tasks(self, num_games): + tasks = [] + remaining = num_games + while remaining > 0: + task_size = min(self.chunk_size, remaining) + tasks.append({"num_games": task_size, "pending": True, "active": False}) + remaining -= task_size + return tasks + + def new_run( + self, + base_tag, + new_tag, + num_games, + tc, + book, + book_depth, + threads, + base_options, + new_options, + info="", + resolved_base="", + resolved_new="", + msg_base="", + msg_new="", + base_signature="", + new_signature="", + base_net=None, + new_net=None, + rescheduled_from=None, + base_same_as_master=None, + start_time=None, + sprt=None, + spsa=None, + username=None, + tests_repo=None, + auto_purge=False, + throughput=100, + priority=0, + ): + if start_time is None: + start_time = datetime.utcnow() + + run_args = { + "base_tag": base_tag, + "new_tag": new_tag, + "base_net": base_net, + "new_net": new_net, + "num_games": num_games, + "tc": tc, + "book": book, + "book_depth": book_depth, + "threads": threads, + "resolved_base": resolved_base, + "resolved_new": resolved_new, + "msg_base": msg_base, + "msg_new": msg_new, + "base_options": base_options, + "new_options": new_options, + "info": info, + "base_signature": base_signature, + "new_signature": new_signature, + "username": username, + "tests_repo": tests_repo, + "auto_purge": auto_purge, + "throughput": throughput, + "itp": 100, # internal throughput + "priority": priority, + } + + if sprt is not None: + run_args["sprt"] = sprt + + if spsa is not None: + run_args["spsa"] = spsa + + tc_base = re.search("^(\d+(\.\d+)?)", tc) + if tc_base: + tc_base = float(tc_base.group(1)) + new_run = { + "args": run_args, + "start_time": start_time, + "last_updated": start_time, + "tc_base": tc_base, + "base_same_as_master": base_same_as_master, + # Will be filled in by tasks, indexed by task-id + "tasks": self.generate_tasks(num_games), + # Aggregated results + "results": {"wins": 0, "losses": 0, "draws": 0}, + "results_stale": False, + "finished": False, + "approved": False, + "approver": "", + } + + if rescheduled_from: + new_run["rescheduled_from"] = rescheduled_from + + return self.runs.insert_one(new_run).inserted_id + + def get_machines(self): + machines = [] + active_runs = self.runs.find( + {"finished": False, "tasks": {"$elemMatch": {"active": True}}}, + sort=[("last_updated", DESCENDING)], + ) + for run in active_runs: + for task in run["tasks"]: + if task["active"]: + machine = copy.copy(task["worker_info"]) + machine["last_updated"] = task.get("last_updated", None) + machine["run"] = run + machine["nps"] = task.get("nps", 0) + machines.append(machine) + return machines + + def get_pgn(self, pgn_id): + pgn_id = pgn_id.split(".")[0] # strip .pgn + pgn = self.pgndb.find_one({"run_id": pgn_id}) + if pgn: + return zlib.decompress(pgn["pgn_zip"]).decode() return None - def start_timer(self): - self.timer = threading.Timer(1.0, self.flush_buffers) - self.timer.start() + def get_pgn_100(self, skip): + return [ + p["run_id"] + for p in self.pgndb.find(skip=skip, limit=100, sort=[("_id", DESCENDING)]) + ] + + def upload_nn(self, userid, name, nn): + self.nndb.insert_one({"user": userid, "name": name, 
"downloads": 0}) + # 'nn': Binary(zlib.compress(nn))}) + return {} + + def update_nn(self, net): + net.pop("downloads", None) + self.nndb.update_one({"name": net["name"]}, {"$set": net}) + + def get_nn(self, name): + # nn = self.nndb.find_one({'name': name}) + nn = self.nndb.find_one({"name": name}, {"nn": 0}) + if nn: + self.nndb.update_one({"name": name}, {"$inc": {"downloads": 1}}) + return nn + return None - def buffer(self, run, flush): - with self.run_cache_lock: - if self.timer is None: + def get_nns(self, limit): + return [ + dict(n, time=n["_id"].generation_time) + for n in self.nndb.find( + {}, {"nn": 0}, limit=limit, sort=[("_id", DESCENDING)] + ) + ] + + # Cache runs + run_cache = {} + run_cache_lock = threading.Lock() + run_cache_write_lock = threading.Lock() + + timer = None + + # handle termination + def exit_run(signum, frame): + global last_rundb + if last_rundb: + last_rundb.flush_all() + sys.exit(0) + + signal.signal(signal.SIGINT, exit_run) + signal.signal(signal.SIGTERM, exit_run) + + def get_run(self, r_id): + with self.run_cache_lock: + r_id = str(r_id) + if r_id in self.run_cache: + self.run_cache[r_id]["rtime"] = time.time() + return self.run_cache[r_id]["run"] + try: + run = self.runs.find_one({"_id": ObjectId(r_id)}) + if run: + self.run_cache[r_id] = { + "rtime": time.time(), + "ftime": time.time(), + "run": run, + "dirty": False, + } + return run + except: + return None + + def start_timer(self): + self.timer = threading.Timer(1.0, self.flush_buffers) + self.timer.start() + + def buffer(self, run, flush): + with self.run_cache_lock: + if self.timer is None: + self.start_timer() + r_id = str(run["_id"]) + if flush: + self.run_cache[r_id] = { + "dirty": False, + "rtime": time.time(), + "ftime": time.time(), + "run": run, + } + with self.run_cache_write_lock: + self.runs.replace_one({"_id": ObjectId(r_id)}, run) + else: + if r_id in self.run_cache: + ftime = self.run_cache[r_id]["ftime"] + else: + ftime = time.time() + self.run_cache[r_id] = { + "dirty": True, + "rtime": time.time(), + "ftime": ftime, + "run": run, + } + + def stop(self): + self.flush_all() + with self.run_cache_lock: + self.timer = None + time.sleep(1.1) + + def flush_all(self): + print("flush") + # Note that we do not grab locks because this method is + # called from a signal handler and grabbing locks might deadlock + for r_id in list(self.run_cache): + if self.run_cache[r_id]["dirty"]: + self.runs.replace_one( + {"_id": ObjectId(r_id)}, self.run_cache[r_id]["run"] + ) + print(".", end="") + print("done") + + def flush_buffers(self): + if self.timer is None: + return + self.run_cache_lock.acquire() + now = time.time() + old = now + 1 + oldest = None + for r_id in list(self.run_cache): + if not self.run_cache[r_id]["dirty"]: + if self.run_cache[r_id]["rtime"] < now - 60: + del self.run_cache[r_id] + elif self.run_cache[r_id]["ftime"] < old: + old = self.run_cache[r_id]["ftime"] + oldest = r_id + if oldest is not None: + if int(now) % 60 == 0: + self.scavenge(self.run_cache[oldest]["run"]) + self.run_cache[oldest]["dirty"] = False + self.run_cache[oldest]["ftime"] = time.time() + self.run_cache_lock.release() # Release the lock while writing + # print("SYNC") + with self.run_cache_write_lock: + self.runs.save(self.run_cache[oldest]["run"]) + # start the timer when writing is done + self.start_timer() + return + # Nothing to flush, start timer: self.start_timer() - r_id = str(run['_id']) - if flush: - self.run_cache[r_id] = {'dirty': False, 'rtime': time.time(), - 'ftime': time.time(), 'run': run} 
+ self.run_cache_lock.release() + + def scavenge(self, run): + old = datetime.utcnow() - timedelta(minutes=30) + for task in run["tasks"]: + if task["active"] and task["last_updated"] < old: + task["active"] = False + + def get_unfinished_runs_id(self): + with self.run_cache_write_lock: + unfinished_runs = self.runs.find( + {"finished": False}, {"_id": 1}, sort=[("last_updated", DESCENDING)] + ) + return unfinished_runs + + def get_unfinished_runs(self, username=None): with self.run_cache_write_lock: - self.runs.replace_one({ '_id': ObjectId(r_id) }, run) - else: - if r_id in self.run_cache: - ftime = self.run_cache[r_id]['ftime'] + unfinished_runs = self.runs.find( + {"finished": False}, sort=[("last_updated", DESCENDING)] + ) + if username: + unfinished_runs = [ + r for r in unfinished_runs if r["args"].get("username") == username + ] + return unfinished_runs + + def aggregate_unfinished_runs(self, username=None): + unfinished_runs = self.get_unfinished_runs(username) + runs = {"pending": [], "active": []} + for run in unfinished_runs: + state = ( + "active" if any(task["active"] for task in run["tasks"]) else "pending" + ) + runs[state].append(run) + runs["pending"].sort( + key=lambda run: ( + run["args"]["priority"], + run["args"]["itp"] if "itp" in run["args"] else 100, + ) + ) + runs["active"].sort( + reverse=True, + key=lambda run: ( + "sprt" in run["args"], + run["args"].get("sprt", {}).get("llr", 0), + "spsa" not in run["args"], + run["results"]["wins"] + + run["results"]["draws"] + + run["results"]["losses"], + ), + ) + + # Calculate but don't save results_info on runs using info on current machines + cores = 0 + nps = 0 + for m in self.get_machines(): + concurrency = int(m["concurrency"]) + cores += concurrency + nps += concurrency * m["nps"] + pending_hours = 0 + for run in runs["pending"] + runs["active"]: + if cores > 0: + eta = remaining_hours(run) / cores + pending_hours += eta + results = self.get_results(run, False) + run["results_info"] = format_results(results, run) + if "Pending..." 
in run["results_info"]["info"]: + if cores > 0: + run["results_info"]["info"][0] += " (%.1f hrs)" % (eta) + if "sprt" in run["args"]: + sprt = run["args"]["sprt"] + elo_model = sprt.get("elo_model", "BayesElo") + if elo_model == "BayesElo": + run["results_info"]["info"].append( + ("[%.2f,%.2f]") % (sprt["elo0"], sprt["elo1"]) + ) + else: + run["results_info"]["info"].append( + ("{%.2f,%.2f}") % (sprt["elo0"], sprt["elo1"]) + ) + return (runs, pending_hours, cores, nps) + + def get_finished_runs( + self, + skip=0, + limit=0, + username="", + success_only=False, + yellow_only=False, + ltc_only=False, + ): + q = {"finished": True} + idx_hint = "finished_runs" + if username: + q["args.username"] = username + idx_hint = None + if ltc_only: + q["tc_base"] = {"$gte": 40} + idx_hint = "finished_ltc_runs" + if success_only: + q["is_green"] = True + idx_hint = "finished_green_runs" + if yellow_only: + q["is_yellow"] = True + idx_hint = "finished_yellow_runs" + + c = self.runs.find( + q, skip=skip, limit=limit, sort=[("last_updated", DESCENDING)] + ) + + if idx_hint: + # Use a fast COUNT_SCAN query when possible + count = self.runs.estimated_document_count(hint=idx_hint) else: - ftime = time.time() - self.run_cache[r_id] = {'dirty': True, 'rtime': time.time(), - 'ftime': ftime, 'run': run} - - def stop(self): - self.flush_all() - with self.run_cache_lock: - self.timer = None - time.sleep(1.1) - - def flush_all(self): - print("flush") - # Note that we do not grab locks because this method is - # called from a signal handler and grabbing locks might deadlock - for r_id in list(self.run_cache): - if self.run_cache[r_id]['dirty']: - self.runs.replace_one({ '_id': ObjectId(r_id) }, self.run_cache[r_id]['run']) - print(".", end='') - print("done") - - def flush_buffers(self): - if self.timer is None: - return - self.run_cache_lock.acquire() - now = time.time() - old = now + 1 - oldest = None - for r_id in list(self.run_cache): - if not self.run_cache[r_id]['dirty']: - if self.run_cache[r_id]['rtime'] < now - 60: - del self.run_cache[r_id] - elif self.run_cache[r_id]['ftime'] < old: - old = self.run_cache[r_id]['ftime'] - oldest = r_id - if oldest is not None: - if int(now) % 60 == 0: - self.scavenge(self.run_cache[oldest]['run']) - self.run_cache[oldest]['dirty'] = False - self.run_cache[oldest]['ftime'] = time.time() - self.run_cache_lock.release() # Release the lock while writing - # print("SYNC") - with self.run_cache_write_lock: - self.runs.save(self.run_cache[oldest]['run']) - # start the timer when writing is done - self.start_timer() - return - # Nothing to flush, start timer: - self.start_timer() - self.run_cache_lock.release() - - def scavenge(self, run): - old = datetime.utcnow() - timedelta(minutes=30) - for task in run['tasks']: - if task['active'] and task['last_updated'] < old: - task['active'] = False - - def get_unfinished_runs_id(self): - with self.run_cache_write_lock: - unfinished_runs = self.runs.find({'finished': False}, - {'_id': 1}, - sort=[('last_updated', DESCENDING)]) - return unfinished_runs - - def get_unfinished_runs(self, username=None): - with self.run_cache_write_lock: - unfinished_runs = self.runs.find({'finished': False}, - sort=[('last_updated', DESCENDING)]) - if username: - unfinished_runs = [r for r in unfinished_runs if r['args'].get('username') == username] - return unfinished_runs - - def aggregate_unfinished_runs(self, username=None): - unfinished_runs = self.get_unfinished_runs(username) - runs = {'pending': [], 'active': []} - for run in unfinished_runs: - state 
= 'active' if any(task['active'] for task in run['tasks']) else 'pending' - runs[state].append(run) - runs['pending'].sort(key=lambda run: (run['args']['priority'], - run['args']['itp'] - if 'itp' in run['args'] else 100)) - runs['active'].sort(reverse=True, key=lambda run: ( - 'sprt' in run['args'], - run['args'].get('sprt',{}).get('llr',0), - 'spsa' not in run['args'], - run['results']['wins'] + run['results']['draws'] - + run['results']['losses'])) - - # Calculate but don't save results_info on runs using info on current machines - cores = 0 - nps = 0 - for m in self.get_machines(): - concurrency = int(m['concurrency']) - cores += concurrency - nps += concurrency * m['nps'] - pending_hours = 0 - for run in runs['pending'] + runs['active']: - if cores > 0: - eta = remaining_hours(run) / cores - pending_hours += eta - results = self.get_results(run, False) - run['results_info'] = format_results(results, run) - if 'Pending...' in run['results_info']['info']: - if cores > 0: - run['results_info']['info'][0] += ' (%.1f hrs)' % (eta) - if 'sprt' in run['args']: - sprt = run['args']['sprt'] - elo_model = sprt.get('elo_model', 'BayesElo') - if elo_model == 'BayesElo': - run['results_info']['info'].append(('[%.2f,%.2f]') - % (sprt['elo0'], sprt['elo1'])) - else: - run['results_info']['info'].append(('{%.2f,%.2f}') - % (sprt['elo0'], sprt['elo1'])) - return (runs, pending_hours, cores, nps) - - - def get_finished_runs(self, skip=0, limit=0, username='', - success_only=False, yellow_only=False, ltc_only=False): - q = {'finished': True} - idx_hint = 'finished_runs' - if username: - q['args.username'] = username - idx_hint = None - if ltc_only: - q['tc_base'] = {'$gte': 40} - idx_hint = 'finished_ltc_runs' - if success_only: - q['is_green'] = True - idx_hint = 'finished_green_runs' - if yellow_only: - q['is_yellow'] = True - idx_hint = 'finished_yellow_runs' - - c = self.runs.find(q, skip=skip, limit=limit, - sort=[('last_updated', DESCENDING)]) - - if idx_hint: - # Use a fast COUNT_SCAN query when possible - count = self.runs.estimated_document_count(hint=idx_hint) - else: - # Otherwise, the count is slow - count = c.count() - # Don't show runs that were deleted - runs_list = [run for run in c if not run.get('deleted')] - return [runs_list, count] - - def get_results(self, run, save_run=True): - if not run['results_stale']: - return run['results'] - - results = {'wins': 0, 'losses': 0, 'draws': 0, - 'crashes': 0, 'time_losses': 0} - - has_pentanomial = True - pentanomial = 5*[0] - for task in run['tasks']: - if 'stats' in task: - stats = task['stats'] - results['wins'] += stats['wins'] - results['losses'] += stats['losses'] - results['draws'] += stats['draws'] - results['crashes'] += stats['crashes'] - results['time_losses'] += stats.get('time_losses', 0) - if 'pentanomial' in stats.keys() and has_pentanomial: - pentanomial = [pentanomial[i]+stats['pentanomial'][i] - for i in range(0, 5)] + # Otherwise, the count is slow + count = c.count() + # Don't show runs that were deleted + runs_list = [run for run in c if not run.get("deleted")] + return [runs_list, count] + + def get_results(self, run, save_run=True): + if not run["results_stale"]: + return run["results"] + + results = {"wins": 0, "losses": 0, "draws": 0, "crashes": 0, "time_losses": 0} + + has_pentanomial = True + pentanomial = 5 * [0] + for task in run["tasks"]: + if "stats" in task: + stats = task["stats"] + results["wins"] += stats["wins"] + results["losses"] += stats["losses"] + results["draws"] += stats["draws"] + results["crashes"] 
+= stats["crashes"] + results["time_losses"] += stats.get("time_losses", 0) + if "pentanomial" in stats.keys() and has_pentanomial: + pentanomial = [ + pentanomial[i] + stats["pentanomial"][i] for i in range(0, 5) + ] + else: + has_pentanomial = False + if has_pentanomial: + results["pentanomial"] = pentanomial + + run["results_stale"] = False + run["results"] = results + if save_run: + self.buffer(run, True) + + return results + + def calc_itp(self, run): + itp = run["args"]["throughput"] + if itp < 1: + itp = 1 + elif itp > 500: + itp = 500 + itp *= math.sqrt( + estimate_game_duration(run["args"]["tc"]) / estimate_game_duration("10+0.1") + ) + itp *= math.sqrt(run["args"]["threads"]) + if "sprt" not in run["args"]: + itp *= 0.5 else: - has_pentanomial = False - if has_pentanomial: - results['pentanomial'] = pentanomial - - run['results_stale'] = False - run['results'] = results - if save_run: - self.buffer(run, True) - - return results - - def calc_itp(self, run): - itp = run['args']['throughput'] - if itp < 1: - itp = 1 - elif itp > 500: - itp = 500 - itp *= math.sqrt(estimate_game_duration(run['args']['tc'])/estimate_game_duration('10+0.1')) - itp *= math.sqrt(run['args']['threads']) - if 'sprt' not in run['args']: - itp *= 0.5 - else: - llr = run['args']['sprt'].get('llr',0) - itp *= (5 + llr) / 5 - run['args']['itp'] = itp - - def sum_cores(self, run): - cores = 0 - for task in run['tasks']: - if task['active']: - cores += int(task['worker_info']['concurrency']) - run['cores'] = cores - - # Limit concurrent request_task - task_lock = threading.Lock() - task_semaphore = threading.Semaphore(4) - - task_time = 0 - task_runs = None - - worker_runs = {} - - def request_task(self, worker_info): - if self.task_semaphore.acquire(False): - try: - with self.task_lock: - return self.sync_request_task(worker_info) - finally: - self.task_semaphore.release() - else: - print("request_task too busy") - return {'task_waiting': False} - - def sync_request_task(self, worker_info): - if time.time() > self.task_time + 60: - self.task_runs = [] - for r in self.get_unfinished_runs_id(): - run = self.get_run(r['_id']) - self.sum_cores(run) - self.calc_itp(run) - self.task_runs.append(run) - self.task_runs.sort(key=lambda r: (-r['args']['priority'], - r['cores'] / r['args']['itp'] * 100.0, - -r['args']['itp'], r['_id'])) - self.task_time = time.time() - - max_threads = int(worker_info['concurrency']) - min_threads = int(worker_info.get('min_threads', 1)) - max_memory = int(worker_info.get('max_memory', 0)) - - # We need to allocate a new task, but first check we don't have the same - # machine already running because multiple connections are not allowed. 
- connections = 0 - for run in self.task_runs: - for task in run['tasks']: - if (task['active'] - and task['worker_info']['remote_addr'] - == worker_info['remote_addr']): - connections = connections + 1 - - # Allow a few connections, for multiple computers on same IP - if connections >= self.userdb.get_machine_limit(worker_info['username']): - return {'task_waiting': False, 'hit_machine_limit': True} - - # Limit worker Github API calls - if 'rate' in worker_info: - rate = worker_info['rate'] - limit = rate['remaining'] <= 2 * math.sqrt(rate['limit']) - else: - limit = False - worker_key = worker_info['unique_key'] - - # Get a new task that matches the worker requirements - run_found = False - for run in self.task_runs: - # compute required TT memory - need_tt = 0 - if max_memory > 0: - def get_hash(s): - h = re.search('Hash=([0-9]+)', s) - if h: - return int(h.group(1)) - return 0 - need_tt += get_hash(run['args']['new_options']) - need_tt += get_hash(run['args']['base_options']) - need_tt *= max_threads // run['args']['threads'] - - if run['approved'] \ - and (not limit or (worker_key in self.worker_runs - and run['_id'] in self.worker_runs[worker_key])) \ - and run['args']['threads'] <= max_threads \ - and run['args']['threads'] >= min_threads \ - and need_tt <= max_memory: - task_id = -1 + llr = run["args"]["sprt"].get("llr", 0) + itp *= (5 + llr) / 5 + run["args"]["itp"] = itp + + def sum_cores(self, run): cores = 0 - if 'spsa' in run['args']: - limit_cores = 40000 / math.sqrt(len(run['args']['spsa']['params'])) + for task in run["tasks"]: + if task["active"]: + cores += int(task["worker_info"]["concurrency"]) + run["cores"] = cores + + # Limit concurrent request_task + task_lock = threading.Lock() + task_semaphore = threading.Semaphore(4) + + task_time = 0 + task_runs = None + + worker_runs = {} + + def request_task(self, worker_info): + if self.task_semaphore.acquire(False): + try: + with self.task_lock: + return self.sync_request_task(worker_info) + finally: + self.task_semaphore.release() + else: + print("request_task too busy") + return {"task_waiting": False} + + def sync_request_task(self, worker_info): + if time.time() > self.task_time + 60: + self.task_runs = [] + for r in self.get_unfinished_runs_id(): + run = self.get_run(r["_id"]) + self.sum_cores(run) + self.calc_itp(run) + self.task_runs.append(run) + self.task_runs.sort( + key=lambda r: ( + -r["args"]["priority"], + r["cores"] / r["args"]["itp"] * 100.0, + -r["args"]["itp"], + r["_id"], + ) + ) + self.task_time = time.time() + + max_threads = int(worker_info["concurrency"]) + min_threads = int(worker_info.get("min_threads", 1)) + max_memory = int(worker_info.get("max_memory", 0)) + + # We need to allocate a new task, but first check we don't have the same + # machine already running because multiple connections are not allowed. 
+ connections = 0 + for run in self.task_runs: + for task in run["tasks"]: + if ( + task["active"] + and task["worker_info"]["remote_addr"] == worker_info["remote_addr"] + ): + connections = connections + 1 + + # Allow a few connections, for multiple computers on same IP + if connections >= self.userdb.get_machine_limit(worker_info["username"]): + return {"task_waiting": False, "hit_machine_limit": True} + + # Limit worker Github API calls + if "rate" in worker_info: + rate = worker_info["rate"] + limit = rate["remaining"] <= 2 * math.sqrt(rate["limit"]) else: - limit_cores = 1000000 # No limit for SPRT - for task in run['tasks']: - if task['active']: - cores += task['worker_info']['concurrency'] - if cores > limit_cores: - break - task_id = task_id + 1 - if not task['active'] and task['pending']: - task['worker_info'] = worker_info - task['last_updated'] = datetime.utcnow() - task['active'] = True - run_found = True - break - if run_found: - break - - if not run_found: - return {'task_waiting': False} - - self.sum_cores(run) - self.task_runs.sort(key=lambda r: (-r['args']['priority'], - r['cores'] / r['args']['itp'] * 100.0, - -r['args']['itp'], r['_id'])) - - self.buffer(run, False) - - # Update worker_runs (compiled tests) - if worker_key not in self.worker_runs: - self.worker_runs[worker_key] = {} - if run['_id'] not in self.worker_runs[worker_key]: - self.worker_runs[worker_key][run['_id']] = True - - return {'run': run, 'task_id': task_id} - - # Create a lock for each active run - run_lock = threading.Lock() - active_runs = {} - purge_count = 0 - - def active_run_lock(self, id): - with self.run_lock: - self.purge_count = self.purge_count + 1 - if self.purge_count > 100000: - old = time.time() - 10000 - self.active_runs = dict( - (k, v) for k, v in self.active_runs.items() if v['time'] >= old) - self.purge_count = 0 - if id in self.active_runs: - active_lock = self.active_runs[id]['lock'] - self.active_runs[id]['time'] = time.time() - else: - active_lock = threading.Lock() - self.active_runs[id] = {'time': time.time(), 'lock': active_lock} - return active_lock - - def update_task(self, run_id, task_id, stats, nps, ARCH, spsa, username): - lock = self.active_run_lock(str(run_id)) - with lock: - return self.sync_update_task(run_id, task_id, stats, nps, ARCH, spsa, username) - - def sync_update_task(self, run_id, task_id, stats, nps, ARCH, spsa, username): - run = self.get_run(run_id) - if task_id >= len(run['tasks']): - return {'task_alive': False} - - task = run['tasks'][task_id] - if not task['active'] or not task['pending']: - return {'task_alive': False} - if task['worker_info']['username'] != username: - print('Update_task: Non matching username: ' + username) - return {'task_alive': False} - - # Guard against incorrect results - count_games = lambda d: d['wins'] + d['losses'] + d['draws'] - num_games = count_games(stats) - old_num_games = count_games(task['stats']) if 'stats' in task else 0 - spsa_games = count_games(spsa) if 'spsa' in run['args'] else 0 - if (num_games < old_num_games - or (spsa_games > 0 and num_games <= 0) - or (spsa_games > 0 and 'stats' in task and num_games <= old_num_games) + limit = False + worker_key = worker_info["unique_key"] + + # Get a new task that matches the worker requirements + run_found = False + for run in self.task_runs: + # compute required TT memory + need_tt = 0 + if max_memory > 0: + + def get_hash(s): + h = re.search("Hash=([0-9]+)", s) + if h: + return int(h.group(1)) + return 0 + + need_tt += get_hash(run["args"]["new_options"]) + 
need_tt += get_hash(run["args"]["base_options"]) + need_tt *= max_threads // run["args"]["threads"] + + if ( + run["approved"] + and ( + not limit + or ( + worker_key in self.worker_runs + and run["_id"] in self.worker_runs[worker_key] + ) + ) + and run["args"]["threads"] <= max_threads + and run["args"]["threads"] >= min_threads + and need_tt <= max_memory + ): + task_id = -1 + cores = 0 + if "spsa" in run["args"]: + limit_cores = 40000 / math.sqrt(len(run["args"]["spsa"]["params"])) + else: + limit_cores = 1000000 # No limit for SPRT + for task in run["tasks"]: + if task["active"]: + cores += task["worker_info"]["concurrency"] + if cores > limit_cores: + break + task_id = task_id + 1 + if not task["active"] and task["pending"]: + task["worker_info"] = worker_info + task["last_updated"] = datetime.utcnow() + task["active"] = True + run_found = True + break + if run_found: + break + + if not run_found: + return {"task_waiting": False} + + self.sum_cores(run) + self.task_runs.sort( + key=lambda r: ( + -r["args"]["priority"], + r["cores"] / r["args"]["itp"] * 100.0, + -r["args"]["itp"], + r["_id"], + ) + ) + + self.buffer(run, False) + + # Update worker_runs (compiled tests) + if worker_key not in self.worker_runs: + self.worker_runs[worker_key] = {} + if run["_id"] not in self.worker_runs[worker_key]: + self.worker_runs[worker_key][run["_id"]] = True + + return {"run": run, "task_id": task_id} + + # Create a lock for each active run + run_lock = threading.Lock() + active_runs = {} + purge_count = 0 + + def active_run_lock(self, id): + with self.run_lock: + self.purge_count = self.purge_count + 1 + if self.purge_count > 100000: + old = time.time() - 10000 + self.active_runs = dict( + (k, v) for k, v in self.active_runs.items() if v["time"] >= old + ) + self.purge_count = 0 + if id in self.active_runs: + active_lock = self.active_runs[id]["lock"] + self.active_runs[id]["time"] = time.time() + else: + active_lock = threading.Lock() + self.active_runs[id] = {"time": time.time(), "lock": active_lock} + return active_lock + + def update_task(self, run_id, task_id, stats, nps, ARCH, spsa, username): + lock = self.active_run_lock(str(run_id)) + with lock: + return self.sync_update_task( + run_id, task_id, stats, nps, ARCH, spsa, username + ) + + def sync_update_task(self, run_id, task_id, stats, nps, ARCH, spsa, username): + run = self.get_run(run_id) + if task_id >= len(run["tasks"]): + return {"task_alive": False} + + task = run["tasks"][task_id] + if not task["active"] or not task["pending"]: + return {"task_alive": False} + if task["worker_info"]["username"] != username: + print("Update_task: Non matching username: " + username) + return {"task_alive": False} + + # Guard against incorrect results + count_games = lambda d: d["wins"] + d["losses"] + d["draws"] + num_games = count_games(stats) + old_num_games = count_games(task["stats"]) if "stats" in task else 0 + spsa_games = count_games(spsa) if "spsa" in run["args"] else 0 + if ( + num_games < old_num_games + or (spsa_games > 0 and num_games <= 0) + or (spsa_games > 0 and "stats" in task and num_games <= old_num_games) ): - return {'task_alive': False} - if (num_games-old_num_games)%2!=0: # the worker should only runs game pairs - return {'task_alive': False} - if 'sprt' in run['args']: - batch_size=2*run['args']['sprt'].get('batch_size',1) - if num_games%batch_size != 0: - return {'task_alive': False} - - all_tasks_finished = False - - task['stats'] = stats - task['nps'] = nps - task['ARCH'] = ARCH - if num_games >= task['num_games']: - # 
This task is now finished
- if 'cores' in run:
- run['cores'] -= task['worker_info']['concurrency']
- task['pending'] = False # Make pending False before making active false
- # to prevent race in request_task
- task['active'] = False
- # Check if all tasks in the run have been finished
- if not any([t['pending'] or t['active'] for t in run['tasks']]):
- all_tasks_finished = True
-
- update_time = datetime.utcnow()
- task['last_updated'] = update_time
- run['last_updated'] = update_time
- run['results_stale'] = True
-
- # Update SPSA results
- if 'spsa' in run['args'] and spsa_games == spsa['num_games']:
- self.update_spsa(task['worker_info']['unique_key'], run, spsa)
-
- # Check SPRT state to decide whether to stop the run
- if 'sprt' in run['args']:
- sprt = run['args']['sprt']
- fishtest.stats.stat_util.update_SPRT(self.get_results(run, False), sprt)
- if sprt['state'] != '':
- # If SPRT is accepted or rejected, stop the run
+ return {"task_alive": False}
+ if (
+ num_games - old_num_games
+ ) % 2 != 0: # the worker should only run game pairs
+ return {"task_alive": False}
+ if "sprt" in run["args"]:
+ batch_size = 2 * run["args"]["sprt"].get("batch_size", 1)
+ if num_games % batch_size != 0:
+ return {"task_alive": False}
+
+ all_tasks_finished = False
+
+ task["stats"] = stats
+ task["nps"] = nps
+ task["ARCH"] = ARCH
+ if num_games >= task["num_games"]:
+ # This task is now finished
+ if "cores" in run:
+ run["cores"] -= task["worker_info"]["concurrency"]
+ task["pending"] = False # Make pending False before making active false
+ # to prevent race in request_task
+ task["active"] = False
+ # Check if all tasks in the run have been finished
+ if not any([t["pending"] or t["active"] for t in run["tasks"]]):
+ all_tasks_finished = True
+
+ update_time = datetime.utcnow()
+ task["last_updated"] = update_time
+ run["last_updated"] = update_time
+ run["results_stale"] = True
+
+ # Update SPSA results
+ if "spsa" in run["args"] and spsa_games == spsa["num_games"]:
+ self.update_spsa(task["worker_info"]["unique_key"], run, spsa)
+
+ # Check SPRT state to decide whether to stop the run
+ if "sprt" in run["args"]:
+ sprt = run["args"]["sprt"]
+ fishtest.stats.stat_util.update_SPRT(self.get_results(run, False), sprt)
+ if sprt["state"] != "":
+ # If SPRT is accepted or rejected, stop the run
+ self.buffer(run, True)
+ self.stop_run(run_id)
+ return {"task_alive": False}
+
+ if all_tasks_finished:
+ # If all tasks are finished, stop the run
+ self.buffer(run, True)
+ self.stop_run(run_id)
+ else:
+ self.buffer(run, False)
+ return {"task_alive": task["active"]}
+
+ def upload_pgn(self, run_id, pgn_zip):
+ self.pgndb.insert_one({"run_id": run_id, "pgn_zip": Binary(pgn_zip)})
+ return {}
+
+ def failed_task(self, run_id, task_id):
+ run = self.get_run(run_id)
+ if task_id >= len(run["tasks"]):
+ return {"task_alive": False}
+
+ task = run["tasks"][task_id]
+ if not task["active"] or not task["pending"]:
+ return {"task_alive": False}
+
+ # Mark the task as inactive: it will be rescheduled
+ task["active"] = False
 self.buffer(run, True)
- self.stop_run(run_id)
- return {'task_alive': False}
-
- if all_tasks_finished:
- # If all tasks are finished, stop the run
- self.buffer(run, True)
- self.stop_run(run_id)
- else:
- self.buffer(run, False)
- return {'task_alive': task['active']}
-
- def upload_pgn(self, run_id, pgn_zip):
- self.pgndb.insert_one({'run_id': run_id, 'pgn_zip': Binary(pgn_zip)})
- return {}
-
- def failed_task(self, run_id, task_id):
- run = self.get_run(run_id)
- if task_id >= 
len(run['tasks']): - return {'task_alive': False} - - task = run['tasks'][task_id] - if not task['active'] or not task['pending']: - return {'task_alive': False} - - # Mark the task as inactive: it will be rescheduled - task['active'] = False - self.buffer(run, True) - return {} - - def stop_run(self, run_id, run=None): - """ Stops a run and runs auto-purge if it was enabled + return {} + + def stop_run(self, run_id, run=None): + """ Stops a run and runs auto-purge if it was enabled - Used by the website and API for manually stopping runs - Called during /api/update_task: - for stopping SPRT runs if the test is accepted or rejected - for stopping a run after all games are finished """ - self.clear_params(run_id) - save_it = False - if run is None: - run = self.get_run(run_id) - save_it = True - run['tasks'] = [task for task in run['tasks'] if 'stats' in task] - for task in run['tasks']: - task['pending'] = False - task['active'] = False - if save_it: - self.buffer(run, True) - self.task_time = 0 - # Auto-purge runs here - purged = False - if run['args'].get('auto_purge', True) and 'spsa' not in run['args']: - if self.purge_run(run): - purged = True - run = self.get_run(run['_id']) - results = self.get_results(run, True) - run['results_info'] = format_results(results, run) - self.buffer(run, True) - if not purged: - # The run is now finished and will no longer be updated after this - run['finished'] = True - results = self.get_results(run, True) - run['results_info'] = format_results(results, run) - # De-couple the styling of the run from its finished status - if run['results_info']['style'] == '#44EB44': - run['is_green'] = True - elif run['results_info']['style'] == 'yellow': - run['is_yellow'] = True + self.clear_params(run_id) + save_it = False + if run is None: + run = self.get_run(run_id) + save_it = True + run["tasks"] = [task for task in run["tasks"] if "stats" in task] + for task in run["tasks"]: + task["pending"] = False + task["active"] = False + if save_it: + self.buffer(run, True) + self.task_time = 0 + # Auto-purge runs here + purged = False + if run["args"].get("auto_purge", True) and "spsa" not in run["args"]: + if self.purge_run(run): + purged = True + run = self.get_run(run["_id"]) + results = self.get_results(run, True) + run["results_info"] = format_results(results, run) + self.buffer(run, True) + if not purged: + # The run is now finished and will no longer be updated after this + run["finished"] = True + results = self.get_results(run, True) + run["results_info"] = format_results(results, run) + # De-couple the styling of the run from its finished status + if run["results_info"]["style"] == "#44EB44": + run["is_green"] = True + elif run["results_info"]["style"] == "yellow": + run["is_yellow"] = True + self.buffer(run, True) + # Publish the results of the run to the Fishcooking forum + post_in_fishcooking_results(run) + + def approve_run(self, run_id, approver): + run = self.get_run(run_id) + # Can't self approve + if run["args"]["username"] == approver: + return False + + run["approved"] = True + run["approver"] = approver self.buffer(run, True) - # Publish the results of the run to the Fishcooking forum - post_in_fishcooking_results(run) - - def approve_run(self, run_id, approver): - run = self.get_run(run_id) - # Can't self approve - if run['args']['username'] == approver: - return False - - run['approved'] = True - run['approver'] = approver - self.buffer(run, True) - self.task_time = 0 - return True - - def purge_run(self, run): - # Remove bad tasks - purged = 
False - chi2 = calculate_residuals(run) - if 'bad_tasks' not in run: - run['bad_tasks'] = [] - for task in run['tasks']: - if task['worker_key'] in chi2['bad_users']: - purged = True - task['bad'] = True - run['bad_tasks'].append(task) - run['tasks'].remove(task) - if purged: - # Generate new tasks if needed - run['results_stale'] = True - results = self.get_results(run) - played_games = results['wins'] + results['losses'] + results['draws'] - if played_games < run['args']['num_games']: - run['tasks'] += self.generate_tasks( - run['args']['num_games'] - played_games) - run['finished'] = False - if 'sprt' in run['args'] and 'state' in run['args']['sprt']: - fishtest.stats.stat_util.update_SPRT(results, run['args']['sprt']) - run['args']['sprt']['state'] = '' - self.buffer(run, True) - return purged - - def spsa_param_clip_round(self, param, increment, clipping, rounding): - if clipping == 'old': - value = param['theta'] + increment - if value < param['min']: - value = param['min'] - elif value > param['max']: - value = param['max'] - else: # clipping == 'careful': - inc = min(abs(increment), abs(param['theta'] - param['min']) / 2, - abs(param['theta'] - param['max']) / 2) - if inc > 0: - value = param['theta'] + inc * increment / abs(increment) - else: # revert to old behavior to bounce off boundary - value = param['theta'] + increment - if value < param['min']: - value = param['min'] - elif value > param['max']: - value = param['max'] - - # 'deterministic' rounding calls round() inside the worker. - # 'randomized' says 4.p should be 5 with probability p, - # 4 with probability 1-p, - # and is continuous (albeit after expectation) unlike round(). - if rounding == 'randomized': - value = math.floor(value + random.uniform(0, 1)) - - return value - - # Store SPSA parameters for each worker - spsa_params = {} - - def store_params(self, run_id, worker, params): - run_id = str(run_id) - if run_id not in self.spsa_params: - self.spsa_params[run_id] = {} - self.spsa_params[run_id][worker] = params - - def get_params(self, run_id, worker): - run_id = str(run_id) - if run_id not in self.spsa_params or worker not in self.spsa_params[run_id]: - # Should only happen after server restart - return self.generate_spsa(self.get_run(run_id))['w_params'] - return self.spsa_params[run_id][worker] - - def clear_params(self, run_id): - run_id = str(run_id) - if run_id in self.spsa_params: - del self.spsa_params[run_id] - - def request_spsa(self, run_id, task_id): - run = self.get_run(run_id) - - if task_id >= len(run['tasks']): - return {'task_alive': False} - task = run['tasks'][task_id] - if not task['active'] or not task['pending']: - return {'task_alive': False} - - result = self.generate_spsa(run) - self.store_params(run['_id'], task['worker_info']['unique_key'], - result['w_params']) - return result - - def generate_spsa(self, run): - result = { - 'task_alive': True, - 'w_params': [], - 'b_params': [], - } - spsa = run['args']['spsa'] - if 'clipping' not in spsa: - spsa['clipping'] = 'old' - if 'rounding' not in spsa: - spsa['rounding'] = 'deterministic' - - # Generate the next set of tuning parameters - iter_local = spsa['iter'] + 1 # assume at least one completed, - # and avoid division by zero - for param in spsa['params']: - c = param['c'] / iter_local ** spsa['gamma'] - flip = 1 if random.getrandbits(1) else -1 - result['w_params'].append({ - 'name': param['name'], - 'value': self.spsa_param_clip_round(param, c * flip, - spsa['clipping'], spsa['rounding']), - 'R': param['a'] / (spsa['A'] + iter_local) 
** spsa['alpha'] / c ** 2, - 'c': c, - 'flip': flip, - }) - result['b_params'].append({ - 'name': param['name'], - 'value': self.spsa_param_clip_round(param, -c * flip, spsa['clipping'], spsa['rounding']), - }) - - return result - - def update_spsa(self, worker, run, spsa_results): - spsa = run['args']['spsa'] - if 'clipping' not in spsa: - spsa['clipping'] = 'old' - - spsa['iter'] += int(spsa_results['num_games'] / 2) - - # Store the history every 'freq' iterations. - # More tuned parameters result in a lower update frequency, - # so that the required storage (performance) remains constant. - if 'param_history' not in spsa: - spsa['param_history'] = [] - L = len(spsa['params']) - freq = L * 25 - if freq < 100: - freq = 100 - maxlen = 250000 / freq - grow_summary = len(spsa['param_history']) < min(maxlen, spsa['iter'] / freq) - - # Update the current theta based on the results from the worker - # Worker wins/losses are always in terms of w_params - result = spsa_results['wins'] - spsa_results['losses'] - summary = [] - w_params = self.get_params(run['_id'], worker) - for idx, param in enumerate(spsa['params']): - R = w_params[idx]['R'] - c = w_params[idx]['c'] - flip = w_params[idx]['flip'] - param['theta'] = self.spsa_param_clip_round(param, R * c * result * flip, - spsa['clipping'], - 'deterministic') - if grow_summary: - summary.append({ - 'theta': param['theta'], - 'R': R, - 'c': c, - }) - - if grow_summary: - spsa['param_history'].append(summary) + self.task_time = 0 + return True + + def purge_run(self, run): + # Remove bad tasks + purged = False + chi2 = calculate_residuals(run) + if "bad_tasks" not in run: + run["bad_tasks"] = [] + for task in run["tasks"]: + if task["worker_key"] in chi2["bad_users"]: + purged = True + task["bad"] = True + run["bad_tasks"].append(task) + run["tasks"].remove(task) + if purged: + # Generate new tasks if needed + run["results_stale"] = True + results = self.get_results(run) + played_games = results["wins"] + results["losses"] + results["draws"] + if played_games < run["args"]["num_games"]: + run["tasks"] += self.generate_tasks( + run["args"]["num_games"] - played_games + ) + run["finished"] = False + if "sprt" in run["args"] and "state" in run["args"]["sprt"]: + fishtest.stats.stat_util.update_SPRT(results, run["args"]["sprt"]) + run["args"]["sprt"]["state"] = "" + self.buffer(run, True) + return purged + + def spsa_param_clip_round(self, param, increment, clipping, rounding): + if clipping == "old": + value = param["theta"] + increment + if value < param["min"]: + value = param["min"] + elif value > param["max"]: + value = param["max"] + else: # clipping == 'careful': + inc = min( + abs(increment), + abs(param["theta"] - param["min"]) / 2, + abs(param["theta"] - param["max"]) / 2, + ) + if inc > 0: + value = param["theta"] + inc * increment / abs(increment) + else: # revert to old behavior to bounce off boundary + value = param["theta"] + increment + if value < param["min"]: + value = param["min"] + elif value > param["max"]: + value = param["max"] + + # 'deterministic' rounding calls round() inside the worker. + # 'randomized' says 4.p should be 5 with probability p, + # 4 with probability 1-p, + # and is continuous (albeit after expectation) unlike round(). 
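+ # Worked example (values illustrative, not from the source): for a
+ # proposed value of 4.7, math.floor(4.7 + random.uniform(0, 1)) yields 5
+ # with probability 0.7 and 4 with probability 0.3, so the expected
+ # rounded value is exactly 4.7.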
+ if rounding == "randomized": + value = math.floor(value + random.uniform(0, 1)) + + return value + + # Store SPSA parameters for each worker + spsa_params = {} + + def store_params(self, run_id, worker, params): + run_id = str(run_id) + if run_id not in self.spsa_params: + self.spsa_params[run_id] = {} + self.spsa_params[run_id][worker] = params + + def get_params(self, run_id, worker): + run_id = str(run_id) + if run_id not in self.spsa_params or worker not in self.spsa_params[run_id]: + # Should only happen after server restart + return self.generate_spsa(self.get_run(run_id))["w_params"] + return self.spsa_params[run_id][worker] + + def clear_params(self, run_id): + run_id = str(run_id) + if run_id in self.spsa_params: + del self.spsa_params[run_id] + + def request_spsa(self, run_id, task_id): + run = self.get_run(run_id) + + if task_id >= len(run["tasks"]): + return {"task_alive": False} + task = run["tasks"][task_id] + if not task["active"] or not task["pending"]: + return {"task_alive": False} + + result = self.generate_spsa(run) + self.store_params( + run["_id"], task["worker_info"]["unique_key"], result["w_params"] + ) + return result + + def generate_spsa(self, run): + result = {"task_alive": True, "w_params": [], "b_params": []} + spsa = run["args"]["spsa"] + if "clipping" not in spsa: + spsa["clipping"] = "old" + if "rounding" not in spsa: + spsa["rounding"] = "deterministic" + + # Generate the next set of tuning parameters + iter_local = spsa["iter"] + 1 # assume at least one completed, + # and avoid division by zero + for param in spsa["params"]: + c = param["c"] / iter_local ** spsa["gamma"] + flip = 1 if random.getrandbits(1) else -1 + result["w_params"].append( + { + "name": param["name"], + "value": self.spsa_param_clip_round( + param, c * flip, spsa["clipping"], spsa["rounding"] + ), + "R": param["a"] + / (spsa["A"] + iter_local) ** spsa["alpha"] + / c ** 2, + "c": c, + "flip": flip, + } + ) + result["b_params"].append( + { + "name": param["name"], + "value": self.spsa_param_clip_round( + param, -c * flip, spsa["clipping"], spsa["rounding"] + ), + } + ) + + return result + + def update_spsa(self, worker, run, spsa_results): + spsa = run["args"]["spsa"] + if "clipping" not in spsa: + spsa["clipping"] = "old" + + spsa["iter"] += int(spsa_results["num_games"] / 2) + + # Store the history every 'freq' iterations. + # More tuned parameters result in a lower update frequency, + # so that the required storage (performance) remains constant. 
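+ # Worked example (values illustrative, not from the source): with L = 8
+ # tuned parameters, freq = max(8 * 25, 100) = 200 and
+ # maxlen = 250000 / 200 = 1250, so at most about 1250 history snapshots
+ # are stored for the run.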
+ if "param_history" not in spsa: + spsa["param_history"] = [] + L = len(spsa["params"]) + freq = L * 25 + if freq < 100: + freq = 100 + maxlen = 250000 / freq + grow_summary = len(spsa["param_history"]) < min(maxlen, spsa["iter"] / freq) + + # Update the current theta based on the results from the worker + # Worker wins/losses are always in terms of w_params + result = spsa_results["wins"] - spsa_results["losses"] + summary = [] + w_params = self.get_params(run["_id"], worker) + for idx, param in enumerate(spsa["params"]): + R = w_params[idx]["R"] + c = w_params[idx]["c"] + flip = w_params[idx]["flip"] + param["theta"] = self.spsa_param_clip_round( + param, R * c * result * flip, spsa["clipping"], "deterministic" + ) + if grow_summary: + summary.append({"theta": param["theta"], "R": R, "c": c}) + + if grow_summary: + spsa["param_history"].append(summary) diff --git a/fishtest/fishtest/stats/LLRcalc.py b/fishtest/fishtest/stats/LLRcalc.py index 3d1bf5ab5..2f45cdac4 100644 --- a/fishtest/fishtest/stats/LLRcalc.py +++ b/fishtest/fishtest/stats/LLRcalc.py @@ -1,12 +1,13 @@ from __future__ import division -import math,sys,copy +import math, sys, copy import scipy import scipy.optimize -def MLE(pdf,s): + +def MLE(pdf, s): """ This function computes the maximum likelood estimate for a discrete distribution with expectation value s, @@ -23,58 +24,64 @@ def MLE(pdf,s): (see Proposition 1.1). """ - epsilon=1e-9 - v,w=pdf[0][0],pdf[-1][0] - assert(v15: - ret=self.outcome_cdf_alt2(T,y) + sigma2 = self.sigma2 + mu = self.mu + gamma = mu / sigma2 + A = self.b - self.a + if sigma2 * T / A ** 2 < 1e-2 or abs(gamma * A) > 15: + ret = self.outcome_cdf_alt2(T, y) else: - ret=self.outcome_cdf_alt1(T,y) - assert -1e-3 <= ret <= 1+1e-3 + ret = self.outcome_cdf_alt1(T, y) + assert -1e-3 <= ret <= 1 + 1e-3 return ret - - def outcome_cdf_alt1(self,T=None,y=None): + + def outcome_cdf_alt1(self, T=None, y=None): """ Computes the probability that the particle passes to the right of (T,y), the time axis being vertically oriented. This may give a numerical exception if math.pi**2*sigma2*T/(2*A**2) is small. 
""" - mu=self.mu - sigma2=self.sigma2 - A=self.b-self.a - x=0-self.a - y=y-self.a - gamma=mu/sigma2 - n=1 - s=0.0 - lambda_1=((math.pi/A)**2)*sigma2/2+(mu**2/sigma2)/2 - t0=math.exp(-lambda_1*T-x*gamma+y*gamma) + mu = self.mu + sigma2 = self.sigma2 + A = self.b - self.a + x = 0 - self.a + y = y - self.a + gamma = mu / sigma2 + n = 1 + s = 0.0 + lambda_1 = ((math.pi / A) ** 2) * sigma2 / 2 + (mu ** 2 / sigma2) / 2 + t0 = math.exp(-lambda_1 * T - x * gamma + y * gamma) while True: - lambda_n=((n*math.pi/A)**2)*sigma2/2+(mu**2/sigma2)/2 - t1=math.exp(-(lambda_n-lambda_1)*T) - t3=U(n,gamma,A,y) - t4=math.sin(n*math.pi*x/A) - s+=t1*t3*t4 - if abs(t0*t1*t3)<=1e-9: + lambda_n = ((n * math.pi / A) ** 2) * sigma2 / 2 + (mu ** 2 / sigma2) / 2 + t1 = math.exp(-(lambda_n - lambda_1) * T) + t3 = U(n, gamma, A, y) + t4 = math.sin(n * math.pi * x / A) + s += t1 * t3 * t4 + if abs(t0 * t1 * t3) <= 1e-9: break - n+=1 - if gamma*A>30: # avoid numerical overflow - pre=math.exp(-2*gamma*x) - elif abs(gamma*A)<1e-8: # avoid division by zero - pre=(A-x)/A + n += 1 + if gamma * A > 30: # avoid numerical overflow + pre = math.exp(-2 * gamma * x) + elif abs(gamma * A) < 1e-8: # avoid division by zero + pre = (A - x) / A else: - pre=(1-math.exp(2*gamma*(A-x)))/(1-math.exp(2*gamma*A)) - return pre+t0*s + pre = (1 - math.exp(2 * gamma * (A - x))) / (1 - math.exp(2 * gamma * A)) + return pre + t0 * s - def outcome_cdf_alt2(self,T=None,y=None): + def outcome_cdf_alt2(self, T=None, y=None): """ Siegmund's approximation. We use it as backup if our exact formula converges too slowly. To make the evaluation robust we use the asymptotic development of Phi. """ - denom=math.sqrt(T*self.sigma2) - offset=self.mu*T - gamma=self.mu/self.sigma2 - a=self.a - b=self.b - z=(y-offset)/denom - za=(-y+offset+2*a)/denom - zb=(y-offset-2*b)/denom - t1=Phi(z) - if gamma*a>=5: - t2=-math.exp(-za**2/2+2*gamma*a)/math.sqrt(2*math.pi)*(1/za-1/za**3) + denom = math.sqrt(T * self.sigma2) + offset = self.mu * T + gamma = self.mu / self.sigma2 + a = self.a + b = self.b + z = (y - offset) / denom + za = (-y + offset + 2 * a) / denom + zb = (y - offset - 2 * b) / denom + t1 = Phi(z) + if gamma * a >= 5: + t2 = ( + -math.exp(-za ** 2 / 2 + 2 * gamma * a) + / math.sqrt(2 * math.pi) + * (1 / za - 1 / za ** 3) + ) else: - t2=math.exp(2*gamma*a)*Phi(za) - if gamma*b>=5: - t3=-math.exp(-zb**2/2+2*gamma*b)/math.sqrt(2*math.pi)*(1/zb-1/zb**3) + t2 = math.exp(2 * gamma * a) * Phi(za) + if gamma * b >= 5: + t3 = ( + -math.exp(-zb ** 2 / 2 + 2 * gamma * b) + / math.sqrt(2 * math.pi) + * (1 / zb - 1 / zb ** 3) + ) else: - t3=math.exp(2*gamma*b)*Phi(zb) - return t1+t2-t3 - - - - + t3 = math.exp(2 * gamma * b) * Phi(zb) + return t1 + t2 - t3 diff --git a/fishtest/fishtest/stats/sprt.py b/fishtest/fishtest/stats/sprt.py index 557e83e29..9b315b802 100644 --- a/fishtest/fishtest/stats/sprt.py +++ b/fishtest/fishtest/stats/sprt.py @@ -1,6 +1,6 @@ from __future__ import division -import math,copy +import math, copy import argparse import scipy.optimize @@ -8,120 +8,146 @@ from fishtest.stats.brownian import Brownian from fishtest.stats import LLRcalc + class sprt: - def __init__(self,alpha=0.05,beta=0.05,elo0=0,elo1=5): - self.a=math.log(beta/(1-alpha)) - self.b=math.log((1-beta)/alpha) - self.elo0=elo0 - self.elo1=elo1 - self.s0=LLRcalc.L_(elo0) - self.s1=LLRcalc.L_(elo1) - self.clamped=False - self.LLR_drift_variance=LLRcalc.LLR_drift_variance_alt2 + def __init__(self, alpha=0.05, beta=0.05, elo0=0, elo1=5): + self.a = math.log(beta / (1 - alpha)) + self.b = 
math.log((1 - beta) / alpha) + self.elo0 = elo0 + self.elo1 = elo1 + self.s0 = LLRcalc.L_(elo0) + self.s1 = LLRcalc.L_(elo1) + self.clamped = False + self.LLR_drift_variance = LLRcalc.LLR_drift_variance_alt2 - def set_state(self,results): - N,self.pdf=LLRcalc.results_to_pdf(results) - mu_LLR,var_LLR=self.LLR_drift_variance(self.pdf,self.s0,self.s1,None) + def set_state(self, results): + N, self.pdf = LLRcalc.results_to_pdf(results) + mu_LLR, var_LLR = self.LLR_drift_variance(self.pdf, self.s0, self.s1, None) # llr estimate - self.llr=N*mu_LLR - self.T=N + self.llr = N * mu_LLR + self.T = N # now normalize llr (if llr is not legal then the implications # of this are unclear) - slope=self.llr/N - if self.llr>1.03*self.b or self.llr<1.03*self.a: - self.clamped=True - if self.llrself.b: - self.T=self.b/slope - self.llr=self.b + slope = self.llr / N + if self.llr > 1.03 * self.b or self.llr < 1.03 * self.a: + self.clamped = True + if self.llr < self.a: + self.T = self.a / slope + self.llr = self.a + elif self.llr > self.b: + self.T = self.b / slope + self.llr = self.b - def outcome_prob(self,elo): + def outcome_prob(self, elo): """ The probability of a test with the given elo with worse outcome (faster fail, slower pass or a pass changed into a fail). """ - s=LLRcalc.L_(elo) - mu_LLR,var_LLR=self.LLR_drift_variance(self.pdf,self.s0,self.s1,s) - sigma_LLR=math.sqrt(var_LLR) - return Brownian(a=self.a,b=self.b,mu=mu_LLR,sigma=sigma_LLR).outcome_cdf(T=self.T,y=self.llr) + s = LLRcalc.L_(elo) + mu_LLR, var_LLR = self.LLR_drift_variance(self.pdf, self.s0, self.s1, s) + sigma_LLR = math.sqrt(var_LLR) + return Brownian(a=self.a, b=self.b, mu=mu_LLR, sigma=sigma_LLR).outcome_cdf( + T=self.T, y=self.llr + ) - def lower_cb(self,p): + def lower_cb(self, p): """ Maximal elo value such that the observed outcome of the test has probability less than p. """ - avg_elo=(self.elo0+self.elo1)/2 - delta=self.elo1-self.elo0 - N=30 -# Various error conditions must be handled better here! + avg_elo = (self.elo0 + self.elo1) / 2 + delta = self.elo1 - self.elo0 + N = 30 + # Various error conditions must be handled better here! 
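+ # Sketch of the loop below: scipy.optimize.brentq() needs a bracket with
+ # a sign change, so [avg_elo - N * delta, avg_elo + N * delta] is widened
+ # (N *= 2) on each ValueError until outcome_prob(elo) - (1 - p) changes
+ # sign, with the bracket clamped to [-1000, 1000] logistic Elo.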
while True:
- elo0=max(avg_elo-N*delta,-1000)
- elo1=min(avg_elo+N*delta,1000)
+ elo0 = max(avg_elo - N * delta, -1000)
+ elo1 = min(avg_elo + N * delta, 1000)
try:
- sol,res=scipy.optimize.brentq(lambda elo:self.outcome_prob(elo)-(1-p),
- elo0,
- elo1,
- full_output=True,
- disp=False)
+ sol, res = scipy.optimize.brentq(
+ lambda elo: self.outcome_prob(elo) - (1 - p),
+ elo0,
+ elo1,
+ full_output=True,
+ disp=False,
+ )
except ValueError:
- if elo0>-1000 or elo1<1000:
- N*=2
+ if elo0 > -1000 or elo1 < 1000:
+ N *= 2
continue
else:
- if self.outcome_prob(elo0)-(1-p)>0:
+ if self.outcome_prob(elo0) - (1 - p) > 0:
return elo1
else:
return elo0
- assert(res.converged)
+ assert res.converged
break
return sol

- def analytics(self,p=0.05):
- ret={}
- ret['clamped']=self.clamped
- ret['a']=self.a
- ret['b']=self.b
- ret['elo']=self.lower_cb(0.5)
- ret['ci']=[self.lower_cb(p/2),self.lower_cb(1-p/2)]
- ret['LOS']=self.outcome_prob(0)
- ret['LLR']=self.llr
+ def analytics(self, p=0.05):
+ ret = {}
+ ret["clamped"] = self.clamped
+ ret["a"] = self.a
+ ret["b"] = self.b
+ ret["elo"] = self.lower_cb(0.5)
+ ret["ci"] = [self.lower_cb(p / 2), self.lower_cb(1 - p / 2)]
+ ret["LOS"] = self.outcome_prob(0)
+ ret["LLR"] = self.llr
return ret

-if __name__=='__main__':
+
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument("--alpha",help="probability of a false positve",type=float,default=0.05)
- parser.add_argument("--beta" ,help="probability of a false negative",type=float,default=0.05)
- parser.add_argument("--elo0", help="H0 (expressed in LogisticElo)",type=float,default=0.0)
- parser.add_argument("--elo1", help="H1 (expressed in LogisticElo)",type=float,default=5.0)
- parser.add_argument("--level",help="confidence level",type=float,default=0.95)
- parser.add_argument("--results", help="trinomial of pentanomial frequencies, low to high",nargs="*",type=int, required=True)
- args=parser.parse_args()
- results=args.results
- if len(results)!=3 and len(results)!=5:
+ parser.add_argument(
+ "--alpha", help="probability of a false positive", type=float, default=0.05
+ )
+ parser.add_argument(
+ "--beta", help="probability of a false negative", type=float, default=0.05
+ )
+ parser.add_argument(
+ "--elo0", help="H0 (expressed in LogisticElo)", type=float, default=0.0
+ )
+ parser.add_argument(
+ "--elo1", help="H1 (expressed in LogisticElo)", type=float, default=5.0
+ )
+ parser.add_argument("--level", help="confidence level", type=float, default=0.95)
+ parser.add_argument(
+ "--results",
+ help="trinomial or pentanomial frequencies, low to high",
+ nargs="*",
+ type=int,
+ required=True,
+ )
+ args = parser.parse_args()
+ results = args.results
+ if len(results) != 3 and len(results) != 5:
parser.error("argument --results: expected 3 or 5 arguments")
- alpha=args.alpha
- beta=args.beta
- elo0=args.elo0
- elo1=args.elo1
- p=1-args.level
- s=sprt(alpha=alpha,beta=beta,elo0=elo0,elo1=elo1)
+ alpha = args.alpha
+ beta = args.beta
+ elo0 = args.elo0
+ elo1 = args.elo1
+ p = 1 - args.level
+ s = sprt(alpha=alpha, beta=beta, elo0=elo0, elo1=elo1)
s.set_state(results)
- a=s.analytics(p)
+ a = s.analytics(p)
print("Design parameters")
print("=================")
- print("False positives : %4.2f%%" % (100*alpha,))
- print("False negatives : %4.2f%%" % (100*beta,) )
- print("[Elo0,Elo1] : [%.2f,%.2f]" % (elo0,elo1))
- print("Confidence level : %4.2f%%" % (100*(1-p),))
+ print("False positives : %4.2f%%" % (100 * alpha,))
+ print("False negatives : %4.2f%%" % (100 * beta,))
+ 
print("[Elo0,Elo1] : [%.2f,%.2f]" % (elo0, elo1)) + print("Confidence level : %4.2f%%" % (100 * (1 - p),)) print("Estimates") print("=========") - print("Elo : %.2f" % a['elo']) - print("Confidence interval : [%.2f,%.2f] (%4.2f%%)" % (a['ci'][0],a['ci'][1],100*(1-p))) - print("LOS : %4.2f%%" % (100*a['LOS'],)) + print("Elo : %.2f" % a["elo"]) + print( + "Confidence interval : [%.2f,%.2f] (%4.2f%%)" + % (a["ci"][0], a["ci"][1], 100 * (1 - p)) + ) + print("LOS : %4.2f%%" % (100 * a["LOS"],)) print("Context") print("=======") - print("LLR [u,l] : %.2f %s [%.2f,%.2f]" % (a['LLR'], '(clamped)' if a['clamped'] else '',a['a'],a['b'])) + print( + "LLR [u,l] : %.2f %s [%.2f,%.2f]" + % (a["LLR"], "(clamped)" if a["clamped"] else "", a["a"], a["b"]) + ) diff --git a/fishtest/fishtest/stats/stat_util.py b/fishtest/fishtest/stats/stat_util.py index 24cbf17c1..e081eeeae 100644 --- a/fishtest/fishtest/stats/stat_util.py +++ b/fishtest/fishtest/stats/stat_util.py @@ -1,6 +1,6 @@ from __future__ import division -import math,copy +import math, copy import scipy.stats @@ -8,183 +8,201 @@ from fishtest.stats import sprt from fishtest.stats import brownian + def Phi(q): - """ + """ Cumlative distribution function for the standard Gaussian law: quantile -> probability """ - return scipy.stats.norm.cdf(q) + return scipy.stats.norm.cdf(q) + def Phi_inv(p): - """ + """ Quantile function for the standard Gaussian law: probability -> quantile """ - return scipy.stats.norm.ppf(p) + return scipy.stats.norm.ppf(p) + def elo(x): - epsilon=1e-3 - x=max(x,epsilon) - x=min(x,1-epsilon) - return -400*math.log10(1/x-1) + epsilon = 1e-3 + x = max(x, epsilon) + x = min(x, 1 - epsilon) + return -400 * math.log10(1 / x - 1) + def L(x): - return 1/(1+10**(-x/400.0)) + return 1 / (1 + 10 ** (-x / 400.0)) + def stats(results): - """ + """ "results" is an array of length 2*n+1 with aggregated frequences for n games. """ - l=len(results) - N=sum(results) - games=N*(l-1)/2.0 + l = len(results) + N = sum(results) + games = N * (l - 1) / 2.0 + + # empirical expected score for a single game + mu = sum([results[i] * (i / 2.0) for i in range(0, l)]) / games -# empirical expected score for a single game - mu=sum([results[i]*(i/2.0) for i in range(0,l)])/games + # empirical expected variance for a single game + mu_ = (l - 1) / 2.0 * mu + var = sum([results[i] * (i / 2.0 - mu_) ** 2.0 for i in range(0, l)]) / games -# empirical expected variance for a single game - mu_=(l-1)/2.0*mu - var=sum([results[i]*(i/2.0-mu_)**2.0 for i in range(0,l)])/games + return games, mu, var - return games,mu,var def get_elo(results): - """ + """ "results" is an array of length 2*n+1 with aggregated frequences for n games. """ - results=LLRcalc.regularize(results) - games,mu,var=stats(results) - stdev = math.sqrt(var) + results = LLRcalc.regularize(results) + games, mu, var = stats(results) + stdev = math.sqrt(var) -# 95% confidence interval for mu - mu_min=mu+Phi_inv(0.025)*stdev/math.sqrt(games) - mu_max=mu+Phi_inv(0.975)*stdev/math.sqrt(games) + # 95% confidence interval for mu + mu_min = mu + Phi_inv(0.025) * stdev / math.sqrt(games) + mu_max = mu + Phi_inv(0.975) * stdev / math.sqrt(games) - el=elo(mu) - elo95=(elo(mu_max)-elo(mu_min))/2.0 - los = Phi((mu-0.5)/(stdev/math.sqrt(games))) + el = elo(mu) + elo95 = (elo(mu_max) - elo(mu_min)) / 2.0 + los = Phi((mu - 0.5) / (stdev / math.sqrt(games))) - return el,elo95,los + return el, elo95, los def bayeselo_to_proba(elo, drawelo): - """ + """ elo is expressed in BayesELO (relative to the choice drawelo). 
Returns a probability, P[2], P[0], P[1] (win,loss,draw). """ - P = 3*[0] - P[2] = 1.0 / (1.0 + pow(10.0, (-elo + drawelo) / 400.0)) - P[0] = 1.0 / (1.0 + pow(10.0, (elo + drawelo) / 400.0)) - P[1] = 1.0 - P[2] - P[0] - return P + P = 3 * [0] + P[2] = 1.0 / (1.0 + pow(10.0, (-elo + drawelo) / 400.0)) + P[0] = 1.0 / (1.0 + pow(10.0, (elo + drawelo) / 400.0)) + P[1] = 1.0 - P[2] - P[0] + return P + def proba_to_bayeselo(P): - """ + """ Takes a probability: P[2], P[0] Returns elo, drawelo. """ - assert(0 < P[2] and P[2] < 1 and 0 < P[0] and P[0] < 1) - elo = 200 * math.log10(P[2]/P[0] * (1-P[0])/(1-P[2])) - drawelo = 200 * math.log10((1-P[0])/P[0] * (1-P[2])/P[2]) - return elo, drawelo + assert 0 < P[2] and P[2] < 1 and 0 < P[0] and P[0] < 1 + elo = 200 * math.log10(P[2] / P[0] * (1 - P[0]) / (1 - P[2])) + drawelo = 200 * math.log10((1 - P[0]) / P[0] * (1 - P[2]) / P[2]) + return elo, drawelo + def draw_elo_calc(R): - """ + """ Takes trinomial frequences R[0],R[1],R[2] (loss,draw,win) and returns the corresponding drawelo value. """ - N=sum(R) - P=[p/N for p in R] - _, drawelo = proba_to_bayeselo(P) - return drawelo + N = sum(R) + P = [p / N for p in R] + _, drawelo = proba_to_bayeselo(P) + return drawelo + def bayeselo_to_elo(belo, drawelo): - P = bayeselo_to_proba(belo, drawelo) - return elo(P[2]+0.5*P[1]) + P = bayeselo_to_proba(belo, drawelo) + return elo(P[2] + 0.5 * P[1]) + def elo_to_bayeselo(elo, draw_ratio): - assert(draw_ratio>=0) - s=L(elo) - P=3*[0] - P[2]=s-draw_ratio/2.0 - P[1]=draw_ratio - P[0]=1-P[1]-P[2] - if P[0]<=0 or P[2]<=0: - return float('NaN'),float('NaN') - return proba_to_bayeselo(P) + assert draw_ratio >= 0 + s = L(elo) + P = 3 * [0] + P[2] = s - draw_ratio / 2.0 + P[1] = draw_ratio + P[0] = 1 - P[1] - P[2] + if P[0] <= 0 or P[2] <= 0: + return float("NaN"), float("NaN") + return proba_to_bayeselo(P) + def SPRT_elo(R, alpha=0.05, beta=0.05, p=0.05, elo0=None, elo1=None, elo_model=None): - """ + """ Calculate an elo estimate from an SPRT test. 
""" - assert(elo_model in ['BayesElo','logistic']) - - # Estimate drawelo out of sample - R3=LLRcalc.regularize([R['losses'],R['draws'],R['wins']]) - drawelo=draw_elo_calc(R3) - - # Convert the bounds to logistic elo if necessary - if elo_model=='BayesElo': - lelo0,lelo1=[bayeselo_to_elo(elo_, drawelo) for elo_ in (elo0,elo1)] - else: - lelo0,lelo1=elo0,elo1 - - # Make the elo estimation object - sp=sprt.sprt(alpha=alpha,beta=beta,elo0=lelo0,elo1=lelo1) - - # Feed the results - if 'pentanomial' in R.keys(): - R_=R['pentanomial'] - else: - R_=R3 - sp.set_state(R_) - - # Get the elo estimates - a=sp.analytics(p) - - # Override the LLR approximation with the exact one - a['LLR']=LLRcalc.LLR_logistic(lelo0,lelo1,R_) - del a['clamped'] - # Now return the estimates - return a - -def LLRlegacy(belo0,belo1,results): - """ + assert elo_model in ["BayesElo", "logistic"] + + # Estimate drawelo out of sample + R3 = LLRcalc.regularize([R["losses"], R["draws"], R["wins"]]) + drawelo = draw_elo_calc(R3) + + # Convert the bounds to logistic elo if necessary + if elo_model == "BayesElo": + lelo0, lelo1 = [bayeselo_to_elo(elo_, drawelo) for elo_ in (elo0, elo1)] + else: + lelo0, lelo1 = elo0, elo1 + + # Make the elo estimation object + sp = sprt.sprt(alpha=alpha, beta=beta, elo0=lelo0, elo1=lelo1) + + # Feed the results + if "pentanomial" in R.keys(): + R_ = R["pentanomial"] + else: + R_ = R3 + sp.set_state(R_) + + # Get the elo estimates + a = sp.analytics(p) + + # Override the LLR approximation with the exact one + a["LLR"] = LLRcalc.LLR_logistic(lelo0, lelo1, R_) + del a["clamped"] + # Now return the estimates + return a + + +def LLRlegacy(belo0, belo1, results): + """ LLR calculation using the BayesElo model where drawelo is estimated "out of sample". """ - assert(len(results)==3) - drawelo=draw_elo_calc(results) - P0=bayeselo_to_proba(belo0,drawelo) - P1=bayeselo_to_proba(belo1,drawelo) - return sum([results[i]*math.log(P1[i]/P0[i]) for i in range(0,3)]) - - -def SPRT(alpha=0.05,beta=0.05,elo0=None,elo1=None,elo_model='logistic',batch_size=1): - """ Constructor for the "sprt object" """ - return {'alpha' : alpha, - 'beta' : beta, - 'elo0' : elo0, - 'elo1' : elo1, - 'elo_model' : elo_model, - 'state' : '', - 'llr' : 0, - 'batch_size' : batch_size, - 'lower_bound' : math.log(beta/(1-alpha)), - 'upper_bound' : math.log((1-beta)/alpha), - 'overshoot' : {'last_update' : 0, - 'skipped_updates': 0, - 'ref0' : 0, - 'm0' : 0, - 'sq0' : 0, - 'ref1' : 0, - 'm1' : 0, - 'sq1' : 0} - } + assert len(results) == 3 + drawelo = draw_elo_calc(results) + P0 = bayeselo_to_proba(belo0, drawelo) + P1 = bayeselo_to_proba(belo1, drawelo) + return sum([results[i] * math.log(P1[i] / P0[i]) for i in range(0, 3)]) + + +def SPRT( + alpha=0.05, beta=0.05, elo0=None, elo1=None, elo_model="logistic", batch_size=1 +): + """ Constructor for the "sprt object" """ + return { + "alpha": alpha, + "beta": beta, + "elo0": elo0, + "elo1": elo1, + "elo_model": elo_model, + "state": "", + "llr": 0, + "batch_size": batch_size, + "lower_bound": math.log(beta / (1 - alpha)), + "upper_bound": math.log((1 - beta) / alpha), + "overshoot": { + "last_update": 0, + "skipped_updates": 0, + "ref0": 0, + "m0": 0, + "sq0": 0, + "ref1": 0, + "m1": 0, + "sq1": 0, + }, + } + def update_SPRT(R, sprt): - """Sequential Probability Ratio Test + """Sequential Probability Ratio Test sprt is a dictionary with fixed fields @@ -222,103 +240,200 @@ def update_SPRT(R, sprt): elo_model can be either 'BayesElo' or 'logistic' """ - # the next two lines are superfluous, but 
unfortunately necessary for backward - # compatibility with old tests - sprt['lower_bound']=math.log(sprt['beta']/(1-sprt['alpha'])) - sprt['upper_bound']=math.log((1-sprt['beta'])/sprt['alpha']) - - elo_model=sprt.get('elo_model', 'BayesElo') - assert(elo_model in ['BayesElo','logistic']) - elo0=sprt['elo0'] - elo1=sprt['elo1'] - - # first deal with the legacy BayesElo/trinomial models - R3=[R['losses'],R['draws'],R['wins']] - if elo_model=='BayesElo': - # estimate drawelo out of sample - R3_=LLRcalc.regularize(R3) - drawelo=draw_elo_calc(R3_) - # conversion of bounds to logistic elo - lelo0,lelo1=[bayeselo_to_elo(elo,drawelo) for elo in (elo0,elo1)] - else: - lelo0,lelo1=elo0,elo1 - - R_=R.get('pentanomial',R3) - - batch_size=sprt.get('batch_size',1) - - # sanity check on batch_size - if sum(R_)%batch_size!=0: - sprt['illegal_update']=sum(R_) # audit - if 'overshoot' in sprt: - del sprt['overshoot'] # the contract is violated - - # Log-Likelihood Ratio - sprt['llr']=LLRcalc.LLR_logistic(lelo0,lelo1,R_) - - # update the overshoot data - if 'overshoot' in sprt: - LLR_=sprt['llr'] - o=sprt['overshoot'] - num_samples=sum(R_) - if num_samples < o['last_update']: # purge? - sprt['lost_samples']=o['last_update']-num_samples # audit - del sprt['overshoot'] # the contract is violated + # the next two lines are superfluous, but unfortunately necessary for backward + # compatibility with old tests + sprt["lower_bound"] = math.log(sprt["beta"] / (1 - sprt["alpha"])) + sprt["upper_bound"] = math.log((1 - sprt["beta"]) / sprt["alpha"]) + + elo_model = sprt.get("elo_model", "BayesElo") + assert elo_model in ["BayesElo", "logistic"] + elo0 = sprt["elo0"] + elo1 = sprt["elo1"] + + # first deal with the legacy BayesElo/trinomial models + R3 = [R["losses"], R["draws"], R["wins"]] + if elo_model == "BayesElo": + # estimate drawelo out of sample + R3_ = LLRcalc.regularize(R3) + drawelo = draw_elo_calc(R3_) + # conversion of bounds to logistic elo + lelo0, lelo1 = [bayeselo_to_elo(elo, drawelo) for elo in (elo0, elo1)] else: - if num_samples==o['last_update']: # same data - pass - elif num_samples==o['last_update']+batch_size: # the normal case - if LLR_o['ref1']: - delta=LLR_-o['ref1'] - o['m1']+=delta - o['sq1']+=delta**2 - o['ref1']=LLR_ - else: - # Be robust if some updates are lost: reset data collection. - # This should not be needed anymore, but just in case... - o['ref0']=LLR_ - o['ref1']=LLR_ - o['skipped_updates']+=(num_samples-o['last_update'])-1 # audit - o['last_update']=num_samples - - o0=0 - o1=0 - if 'overshoot' in sprt: - o=sprt['overshoot'] - o0=-o['sq0']/o['m0']/2 if o['m0']!=0 else 0 - o1=o['sq1']/o['m1']/2 if o['m1']!=0 else 0 - - # now check the stop condition - sprt['state']='' - if sprt['llr'] < sprt['lower_bound']+o0: - sprt['state'] = 'rejected' - elif sprt['llr'] > sprt['upper_bound']-o1: - sprt['state'] = 'accepted' + lelo0, lelo1 = elo0, elo1 + + R_ = R.get("pentanomial", R3) + + batch_size = sprt.get("batch_size", 1) + + # sanity check on batch_size + if sum(R_) % batch_size != 0: + sprt["illegal_update"] = sum(R_) # audit + if "overshoot" in sprt: + del sprt["overshoot"] # the contract is violated + + # Log-Likelihood Ratio + sprt["llr"] = LLRcalc.LLR_logistic(lelo0, lelo1, R_) + + # update the overshoot data + if "overshoot" in sprt: + LLR_ = sprt["llr"] + o = sprt["overshoot"] + num_samples = sum(R_) + if num_samples < o["last_update"]: # purge? 
+ sprt["lost_samples"] = o["last_update"] - num_samples # audit + del sprt["overshoot"] # the contract is violated + else: + if num_samples == o["last_update"]: # same data + pass + elif num_samples == o["last_update"] + batch_size: # the normal case + if LLR_ < o["ref0"]: + delta = LLR_ - o["ref0"] + o["m0"] += delta + o["sq0"] += delta ** 2 + o["ref0"] = LLR_ + if LLR_ > o["ref1"]: + delta = LLR_ - o["ref1"] + o["m1"] += delta + o["sq1"] += delta ** 2 + o["ref1"] = LLR_ + else: + # Be robust if some updates are lost: reset data collection. + # This should not be needed anymore, but just in case... + o["ref0"] = LLR_ + o["ref1"] = LLR_ + o["skipped_updates"] += (num_samples - o["last_update"]) - 1 # audit + o["last_update"] = num_samples + + o0 = 0 + o1 = 0 + if "overshoot" in sprt: + o = sprt["overshoot"] + o0 = -o["sq0"] / o["m0"] / 2 if o["m0"] != 0 else 0 + o1 = o["sq1"] / o["m1"] / 2 if o["m1"] != 0 else 0 + + # now check the stop condition + sprt["state"] = "" + if sprt["llr"] < sprt["lower_bound"] + o0: + sprt["state"] = "rejected" + elif sprt["llr"] > sprt["upper_bound"] - o1: + sprt["state"] = "accepted" + if __name__ == "__main__": - # unit tests - print('SPRT tests') - R={'wins': 65388,'losses': 65804, 'draws': 56553, 'pentanomial':[10789, 19328, 33806, 19402, 10543]} - sprt_=SPRT(elo0=-3, alpha=0.05, elo1=1, beta=0.05, elo_model='logistic') - update_SPRT(R,sprt_) - print(sprt_) - - print('elo tests') - print(SPRT_elo({'wins': 0, 'losses': 0, 'draws': 0}, elo0=0, elo1=5, elo_model='BayesElo')) - print(SPRT_elo({'wins': 10, 'losses': 0, 'draws': 0}, elo0=0, elo1=5, elo_model='BayesElo')) - print(SPRT_elo({'wins': 100, 'losses': 0, 'draws': 0}, elo0=0, elo1=5, elo_model='BayesElo')) - print(SPRT_elo({'wins': 10, 'losses': 0, 'draws': 20}, elo0=0, elo1=5, elo_model='BayesElo')) - print(SPRT_elo({'wins': 10, 'losses': 1, 'draws': 20}, elo0=0, elo1=5, elo_model='BayesElo')) - print(SPRT_elo({'wins': 5019, 'losses': 5026, 'draws': 15699}, elo0=0, elo1=5, elo_model='BayesElo')) - print(SPRT_elo({'wins': 1450, 'losses': 1500, 'draws': 4000}, elo0=0, elo1=6, elo_model='BayesElo')) - print(SPRT_elo({'wins': 716, 'losses': 591, 'draws': 2163}, elo0=0, elo1=6, elo_model='BayesElo')) - print(SPRT_elo({'wins': 13543,'losses': 13624, 'draws': 34333}, elo0=-3, elo1=1, elo_model='BayesElo')) - print(SPRT_elo({'wins': 13543,'losses': 13624, 'draws': 34333, 'pentanomial':[1187, 7410, 13475, 7378, 1164]}, elo0=-3, elo1=1, elo_model='BayesElo')) - print(SPRT_elo({'wins': 65388,'losses': 65804, 'draws': 56553}, elo0=-3, elo1=1, elo_model='BayesElo')) - print(SPRT_elo({'wins': 65388,'losses': 65804, 'draws': 56553, 'pentanomial':[10789, 19328, 33806, 19402, 10543]}, elo0=-3, elo1=1, elo_model='BayesElo')) - print(SPRT_elo({'wins': 65388,'losses': 65804, 'draws': 56553, 'pentanomial':[10789, 19328, 33806, 19402, 10543]}, elo0=-3, elo1=1, elo_model='logistic')) + # unit tests + print("SPRT tests") + R = { + "wins": 65388, + "losses": 65804, + "draws": 56553, + "pentanomial": [10789, 19328, 33806, 19402, 10543], + } + sprt_ = SPRT(elo0=-3, alpha=0.05, elo1=1, beta=0.05, elo_model="logistic") + update_SPRT(R, sprt_) + print(sprt_) + + print("elo tests") + print( + SPRT_elo( + {"wins": 0, "losses": 0, "draws": 0}, elo0=0, elo1=5, elo_model="BayesElo" + ) + ) + print( + SPRT_elo( + {"wins": 10, "losses": 0, "draws": 0}, elo0=0, elo1=5, elo_model="BayesElo" + ) + ) + print( + SPRT_elo( + {"wins": 100, "losses": 0, "draws": 0}, elo0=0, elo1=5, elo_model="BayesElo" + ) + ) + print( + SPRT_elo( + {"wins": 10, 
"losses": 0, "draws": 20}, elo0=0, elo1=5, elo_model="BayesElo" + ) + ) + print( + SPRT_elo( + {"wins": 10, "losses": 1, "draws": 20}, elo0=0, elo1=5, elo_model="BayesElo" + ) + ) + print( + SPRT_elo( + {"wins": 5019, "losses": 5026, "draws": 15699}, + elo0=0, + elo1=5, + elo_model="BayesElo", + ) + ) + print( + SPRT_elo( + {"wins": 1450, "losses": 1500, "draws": 4000}, + elo0=0, + elo1=6, + elo_model="BayesElo", + ) + ) + print( + SPRT_elo( + {"wins": 716, "losses": 591, "draws": 2163}, + elo0=0, + elo1=6, + elo_model="BayesElo", + ) + ) + print( + SPRT_elo( + {"wins": 13543, "losses": 13624, "draws": 34333}, + elo0=-3, + elo1=1, + elo_model="BayesElo", + ) + ) + print( + SPRT_elo( + { + "wins": 13543, + "losses": 13624, + "draws": 34333, + "pentanomial": [1187, 7410, 13475, 7378, 1164], + }, + elo0=-3, + elo1=1, + elo_model="BayesElo", + ) + ) + print( + SPRT_elo( + {"wins": 65388, "losses": 65804, "draws": 56553}, + elo0=-3, + elo1=1, + elo_model="BayesElo", + ) + ) + print( + SPRT_elo( + { + "wins": 65388, + "losses": 65804, + "draws": 56553, + "pentanomial": [10789, 19328, 33806, 19402, 10543], + }, + elo0=-3, + elo1=1, + elo_model="BayesElo", + ) + ) + print( + SPRT_elo( + { + "wins": 65388, + "losses": 65804, + "draws": 56553, + "pentanomial": [10789, 19328, 33806, 19402, 10543], + }, + elo0=-3, + elo1=1, + elo_model="logistic", + ) + ) diff --git a/fishtest/fishtest/userdb.py b/fishtest/fishtest/userdb.py index 35be45de1..b5ab70a90 100644 --- a/fishtest/fishtest/userdb.py +++ b/fishtest/fishtest/userdb.py @@ -7,99 +7,102 @@ class UserDb: - def __init__(self, db): - self.db = db - self.users = self.db['users'] - self.user_cache = self.db['user_cache'] - self.top_month = self.db['top_month'] - self.flag_cache = self.db['flag_cache'] - - # Cache user lookups for 60s - user_lock = threading.Lock() - cache = {} - - def find(self, name): - with self.user_lock: - if name in self.cache: - u = self.cache[name] - if u['time'] > time.time() - 60: - return u['user'] - user = self.users.find_one({'username': name}) - if not user: - return None - self.cache[name] = {'user': user, 'time': time.time()} - return user - - def clear_cache(self): - with self.user_lock: - self.cache.clear() - - def authenticate(self, username, password): - user = self.find(username) - if not user or user['password'] != password: - sys.stderr.write('Invalid login: "%s" "%s"\n' % (username, password)) - return {'error': 'Invalid password'} - if 'blocked' in user and user['blocked']: - sys.stderr.write('Blocked login: "%s" "%s"\n' % (username, password)) - return {'error': 'Blocked'} - - return {'username': username, 'authenticated': True} - - def get_users(self): - return self.users.find(sort=[('_id', ASCENDING)]) - - # Cache pending for 1s - last_pending_time = 0 - last_pending = None - pending_lock = threading.Lock() - - def get_pending(self): - with self.pending_lock: - if time.time() > self.last_pending_time + 1: - self.last_pending = list(self.users.find({'blocked': True}, - sort=[('_id', ASCENDING)])) - self.last_pending_time = time.time() - return self.last_pending - - def get_user(self, username): - return self.find(username) - - def get_user_groups(self, username): - user = self.find(username) - if user: - groups = user['groups'] - return groups - - def add_user_group(self, username, group): - user = self.find(username) - user['groups'].append(group) - self.users.save(user) - - def create_user(self, username, password, email): - try: - if self.find(username): - return False - self.users.insert_one({ - 
'username': username, - 'password': password, - 'registration_time': datetime.utcnow(), - 'blocked': True, - 'email': email, - 'groups': [], - 'tests_repo': '' - }) - self.last_pending_time = 0 - - return True - except: - return False - - def save_user(self, user): - self.users.replace_one({ 'username': user['username'] }, user) - self.last_pending_time = 0 - - def get_machine_limit(self, username): - user = self.find(username) - if user and 'machine_limit' in user: - return user['machine_limit'] - return 16 + def __init__(self, db): + self.db = db + self.users = self.db["users"] + self.user_cache = self.db["user_cache"] + self.top_month = self.db["top_month"] + self.flag_cache = self.db["flag_cache"] + + # Cache user lookups for 60s + user_lock = threading.Lock() + cache = {} + + def find(self, name): + with self.user_lock: + if name in self.cache: + u = self.cache[name] + if u["time"] > time.time() - 60: + return u["user"] + user = self.users.find_one({"username": name}) + if not user: + return None + self.cache[name] = {"user": user, "time": time.time()} + return user + + def clear_cache(self): + with self.user_lock: + self.cache.clear() + + def authenticate(self, username, password): + user = self.find(username) + if not user or user["password"] != password: + sys.stderr.write('Invalid login: "%s" "%s"\n' % (username, password)) + return {"error": "Invalid password"} + if "blocked" in user and user["blocked"]: + sys.stderr.write('Blocked login: "%s" "%s"\n' % (username, password)) + return {"error": "Blocked"} + + return {"username": username, "authenticated": True} + + def get_users(self): + return self.users.find(sort=[("_id", ASCENDING)]) + + # Cache pending for 1s + last_pending_time = 0 + last_pending = None + pending_lock = threading.Lock() + + def get_pending(self): + with self.pending_lock: + if time.time() > self.last_pending_time + 1: + self.last_pending = list( + self.users.find({"blocked": True}, sort=[("_id", ASCENDING)]) + ) + self.last_pending_time = time.time() + return self.last_pending + + def get_user(self, username): + return self.find(username) + + def get_user_groups(self, username): + user = self.find(username) + if user: + groups = user["groups"] + return groups + + def add_user_group(self, username, group): + user = self.find(username) + user["groups"].append(group) + self.users.save(user) + + def create_user(self, username, password, email): + try: + if self.find(username): + return False + self.users.insert_one( + { + "username": username, + "password": password, + "registration_time": datetime.utcnow(), + "blocked": True, + "email": email, + "groups": [], + "tests_repo": "", + } + ) + self.last_pending_time = 0 + + return True + except: + return False + + def save_user(self, user): + self.users.replace_one({"username": user["username"]}, user) + self.last_pending_time = 0 + + def get_machine_limit(self, username): + user = self.find(username) + if user and "machine_limit" in user: + return user["machine_limit"] + return 16 diff --git a/fishtest/fishtest/util.py b/fishtest/fishtest/util.py index 055acb67e..5e9e1144b 100644 --- a/fishtest/fishtest/util.py +++ b/fishtest/fishtest/util.py @@ -12,320 +12,328 @@ UUID_MAP = defaultdict(dict) key_lock = threading.Lock() -FISH_URL = 'https://tests.stockfishchess.org/tests/view/' +FISH_URL = "https://tests.stockfishchess.org/tests/view/" def get_worker_key(task): - global UUID_MAP + global UUID_MAP - if 'worker_info' not in task: - return '-' - username = task['worker_info'].get('username', '') - cores = 
str(task['worker_info']['concurrency']) + if "worker_info" not in task: + return "-" + username = task["worker_info"].get("username", "") + cores = str(task["worker_info"]["concurrency"]) - uuid = task['worker_info'].get('unique_key', '') - with key_lock: - if uuid not in UUID_MAP[username]: - next_idx = len(UUID_MAP[username]) - UUID_MAP[username][uuid] = next_idx + uuid = task["worker_info"].get("unique_key", "") + with key_lock: + if uuid not in UUID_MAP[username]: + next_idx = len(UUID_MAP[username]) + UUID_MAP[username][uuid] = next_idx - worker_key = '%s-%scores' % (username, cores) - suffix = UUID_MAP[username][uuid] - if suffix != 0: - worker_key += "-" + str(suffix) + worker_key = "%s-%scores" % (username, cores) + suffix = UUID_MAP[username][uuid] + if suffix != 0: + worker_key += "-" + str(suffix) - return worker_key + return worker_key def get_chi2(tasks, bad_users): - """ Perform chi^2 test on the stats from each worker """ - results = {'chi2': 0.0, 'dof': 0, 'p': 0.0, 'residual': {}} - - # Aggregate results by worker - users = {} - for task in tasks: - task['worker_key'] = get_worker_key(task) - if 'worker_info' not in task: - continue - key = get_worker_key(task) - if key in bad_users: - continue - stats = task.get('stats', {}) - wld = [float(stats.get('wins', 0)), - float(stats.get('losses', 0)), - float(stats.get('draws', 0))] - if wld == [0.0, 0.0, 0.0]: - continue - if key in users: - for idx in range(len(wld)): - users[key][idx] += wld[idx] - else: - users[key] = wld - - if len(users) == 0: - return results - - observed = numpy.array(list(users.values())) - rows, columns = observed.shape - # Results only from one worker: skip the test for workers homogeneity - if rows == 1: + """ Perform chi^2 test on the stats from each worker """ + results = {"chi2": 0.0, "dof": 0, "p": 0.0, "residual": {}} + + # Aggregate results by worker + users = {} + for task in tasks: + task["worker_key"] = get_worker_key(task) + if "worker_info" not in task: + continue + key = get_worker_key(task) + if key in bad_users: + continue + stats = task.get("stats", {}) + wld = [ + float(stats.get("wins", 0)), + float(stats.get("losses", 0)), + float(stats.get("draws", 0)), + ] + if wld == [0.0, 0.0, 0.0]: + continue + if key in users: + for idx in range(len(wld)): + users[key][idx] += wld[idx] + else: + users[key] = wld + + if len(users) == 0: + return results + + observed = numpy.array(list(users.values())) + rows, columns = observed.shape + # Results only from one worker: skip the test for workers homogeneity + if rows == 1: + return {"chi2": float("nan"), "dof": 0, "p": float("nan"), "residual": {}} + column_sums = numpy.sum(observed, axis=0) + columns_not_zero = sum(i > 0 for i in column_sums) + df = (rows - 1) * (columns - 1) + + if columns_not_zero == 0: + return results + # Results only of one type: workers are identical wrt the test + elif columns_not_zero == 1: + results = {"chi2": 0.0, "dof": df, "p": 1.0, "residual": {}} + return results + # Results only of two types: workers are identical wrt the missing result type + # Change the data shape to avoid divide by zero + elif columns_not_zero == 2: + idx = numpy.argwhere(numpy.all(observed[..., :] == 0, axis=0)) + observed = numpy.delete(observed, idx, axis=1) + column_sums = numpy.sum(observed, axis=0) + + row_sums = numpy.sum(observed, axis=1) + grand_total = numpy.sum(column_sums) + + expected = numpy.outer(row_sums, column_sums) / grand_total + raw_residual = observed - expected + std_error = numpy.sqrt( + expected + * numpy.outer((1 - 
row_sums / grand_total), (1 - column_sums / grand_total)) + ) + adj_residual = raw_residual / std_error + for idx in range(len(users)): + users[list(users.keys())[idx]] = numpy.max(numpy.abs(adj_residual[idx])) + chi2 = numpy.sum(raw_residual * raw_residual / expected) return { - 'chi2': float('nan'), - 'dof': 0, - 'p': float('nan'), - 'residual': {} + "chi2": chi2, + "dof": df, + "p": 1 - scipy.stats.chi2.cdf(chi2, df), + "residual": users, } - column_sums = numpy.sum(observed, axis=0) - columns_not_zero = sum(i > 0 for i in column_sums) - df = (rows - 1) * (columns - 1) - - if columns_not_zero == 0: - return results - # Results only of one type: workers are identical wrt the test - elif columns_not_zero == 1: - results = {'chi2': 0.0, 'dof': df, 'p': 1.0, 'residual': {}} - return results - # Results only of two types: workers are identical wrt the missing result type - # Change the data shape to avoid divide by zero - elif columns_not_zero == 2: - idx = numpy.argwhere(numpy.all(observed[..., :] == 0, axis=0)) - observed = numpy.delete(observed, idx, axis=1) - column_sums = numpy.sum(observed, axis=0) - - row_sums = numpy.sum(observed, axis=1) - grand_total = numpy.sum(column_sums) - - expected = numpy.outer(row_sums, column_sums) / grand_total - raw_residual = observed - expected - std_error = numpy.sqrt(expected * - numpy.outer((1 - row_sums / grand_total), - (1 - column_sums / grand_total))) - adj_residual = raw_residual / std_error - for idx in range(len(users)): - users[list(users.keys())[idx]] = numpy.max(numpy.abs(adj_residual[idx])) - chi2 = numpy.sum(raw_residual * raw_residual / expected) - return { - 'chi2': chi2, - 'dof': df, - 'p': 1 - scipy.stats.chi2.cdf(chi2, df), - 'residual': users, - } def calculate_residuals(run): - bad_users = set() - chi2 = get_chi2(run['tasks'], bad_users) - residuals = chi2['residual'] - - # Limit bad users to 1 for now - for _ in range(1): - worst_user = {} - for task in run['tasks']: - if task['worker_key'] in bad_users: - continue - task['residual'] = residuals.get(task['worker_key'], 0.0) - - # Special case crashes or time losses - stats = task.get('stats', {}) - crashes = stats.get('crashes', 0) - if crashes > 3: - task['residual'] = 8.0 - - if abs(task['residual']) < 2.0: - task['residual_color'] = '#44EB44' - elif abs(task['residual']) < 2.7: - task['residual_color'] = 'yellow' - else: - task['residual_color'] = '#FF6A6A' - - if chi2['p'] < 0.001 or task['residual'] > 7.0: - if len(worst_user) == 0 or task['residual'] > worst_user['residual']: - worst_user['worker_key'] = task['worker_key'] - worst_user['residual'] = task['residual'] - - if len(worst_user) == 0: - break - bad_users.add(worst_user['worker_key']) - residuals = get_chi2(run['tasks'], bad_users)['residual'] - - chi2['bad_users'] = bad_users - return chi2 + bad_users = set() + chi2 = get_chi2(run["tasks"], bad_users) + residuals = chi2["residual"] + + # Limit bad users to 1 for now + for _ in range(1): + worst_user = {} + for task in run["tasks"]: + if task["worker_key"] in bad_users: + continue + task["residual"] = residuals.get(task["worker_key"], 0.0) + + # Special case crashes or time losses + stats = task.get("stats", {}) + crashes = stats.get("crashes", 0) + if crashes > 3: + task["residual"] = 8.0 + + if abs(task["residual"]) < 2.0: + task["residual_color"] = "#44EB44" + elif abs(task["residual"]) < 2.7: + task["residual_color"] = "yellow" + else: + task["residual_color"] = "#FF6A6A" + + if chi2["p"] < 0.001 or task["residual"] > 7.0: + if len(worst_user) == 0 or 
task["residual"] > worst_user["residual"]: + worst_user["worker_key"] = task["worker_key"] + worst_user["residual"] = task["residual"] + + if len(worst_user) == 0: + break + bad_users.add(worst_user["worker_key"]) + residuals = get_chi2(run["tasks"], bad_users)["residual"] + + chi2["bad_users"] = bad_users + return chi2 def format_results(run_results, run): - result = {'style': '', 'info': []} - - # win/loss/draw count - WLD = [run_results['wins'], run_results['losses'], run_results['draws']] - - if 'spsa' in run['args']: - result['info'].append('%d/%d iterations' - % (run['args']['spsa']['iter'], - run['args']['spsa']['num_iter'])) - result['info'].append('%d/%d games played' - % (WLD[0] + WLD[1] + WLD[2], - run['args']['num_games'])) - return result - - # If the score is 0% or 100% the formulas will crash - # anyway the statistics are only asymptotic - if WLD[0] == 0 or WLD[1] == 0: - result['info'].append('Pending...') - return result - - state = 'unknown' - if 'sprt' in run['args']: - sprt = run['args']['sprt'] - state = sprt.get('state', '') - elo_model = sprt.get('elo_model', 'BayesElo') - if not 'llr' in sprt: # legacy - fishtest.stats.stat_util.update_SPRT(run_results,sprt) - if elo_model == 'BayesElo': - result['info'].append('LLR: %.2f (%.2lf,%.2lf) [%.2f,%.2f]' - % (sprt['llr'], - sprt['lower_bound'], sprt['upper_bound'], - sprt['elo0'], sprt['elo1'])) + result = {"style": "", "info": []} + + # win/loss/draw count + WLD = [run_results["wins"], run_results["losses"], run_results["draws"]] + + if "spsa" in run["args"]: + result["info"].append( + "%d/%d iterations" + % (run["args"]["spsa"]["iter"], run["args"]["spsa"]["num_iter"]) + ) + result["info"].append( + "%d/%d games played" % (WLD[0] + WLD[1] + WLD[2], run["args"]["num_games"]) + ) + return result + + # If the score is 0% or 100% the formulas will crash + # anyway the statistics are only asymptotic + if WLD[0] == 0 or WLD[1] == 0: + result["info"].append("Pending...") + return result + + state = "unknown" + if "sprt" in run["args"]: + sprt = run["args"]["sprt"] + state = sprt.get("state", "") + elo_model = sprt.get("elo_model", "BayesElo") + if not "llr" in sprt: # legacy + fishtest.stats.stat_util.update_SPRT(run_results, sprt) + if elo_model == "BayesElo": + result["info"].append( + "LLR: %.2f (%.2lf,%.2lf) [%.2f,%.2f]" + % ( + sprt["llr"], + sprt["lower_bound"], + sprt["upper_bound"], + sprt["elo0"], + sprt["elo1"], + ) + ) + else: + result["info"].append( + "LLR: %.2f (%.2lf,%.2lf) {%.2f,%.2f}" + % ( + sprt["llr"], + sprt["lower_bound"], + sprt["upper_bound"], + sprt["elo0"], + sprt["elo1"], + ) + ) else: - result['info'].append('LLR: %.2f (%.2lf,%.2lf) {%.2f,%.2f}' - % (sprt['llr'], - sprt['lower_bound'], sprt['upper_bound'], - sprt['elo0'], sprt['elo1'])) - else: - if 'pentanomial' in run_results.keys(): - elo, elo95, los = fishtest.stats.stat_util.get_elo( - run_results['pentanomial']) - else: - elo, elo95, los = fishtest.stats.stat_util.get_elo( - [WLD[1], WLD[2], WLD[0]]) - - # Display the results - eloInfo = 'ELO: %.2f +-%.1f (95%%)' % (elo, elo95) - losInfo = 'LOS: %.1f%%' % (los * 100) - - result['info'].append(eloInfo + ' ' + losInfo) - - if los < 0.05: - state = 'rejected' - elif los > 0.95: - state = 'accepted' - - result['info'].append('Total: %d W: %d L: %d D: %d' - % (sum(WLD), WLD[0], WLD[1], WLD[2])) - if 'pentanomial' in run_results.keys(): - result['info'].append("Ptnml(0-2): " + ", ".join( - str(run_results['pentanomial'][i]) for i in range(0, 5))) - - if state == 'rejected': - if WLD[0] > WLD[1]: - 
result['style'] = 'yellow' - else: - result['style'] = '#FF6A6A' - elif state == 'accepted': - if ('sprt' in run['args'] - and (float(sprt['elo0']) + float(sprt['elo1'])) < 0.0): - result['style'] = '#66CCFF' - else: - result['style'] = '#44EB44' - return result - + if "pentanomial" in run_results.keys(): + elo, elo95, los = fishtest.stats.stat_util.get_elo( + run_results["pentanomial"] + ) + else: + elo, elo95, los = fishtest.stats.stat_util.get_elo([WLD[1], WLD[2], WLD[0]]) + + # Display the results + eloInfo = "ELO: %.2f +-%.1f (95%%)" % (elo, elo95) + losInfo = "LOS: %.1f%%" % (los * 100) + + result["info"].append(eloInfo + " " + losInfo) + + if los < 0.05: + state = "rejected" + elif los > 0.95: + state = "accepted" + + result["info"].append( + "Total: %d W: %d L: %d D: %d" % (sum(WLD), WLD[0], WLD[1], WLD[2]) + ) + if "pentanomial" in run_results.keys(): + result["info"].append( + "Ptnml(0-2): " + + ", ".join(str(run_results["pentanomial"][i]) for i in range(0, 5)) + ) + + if state == "rejected": + if WLD[0] > WLD[1]: + result["style"] = "yellow" + else: + result["style"] = "#FF6A6A" + elif state == "accepted": + if "sprt" in run["args"] and (float(sprt["elo0"]) + float(sprt["elo1"])) < 0.0: + result["style"] = "#66CCFF" + else: + result["style"] = "#44EB44" + return result def estimate_game_duration(tc): - # Total time for a game is assumed to be the double of tc for each player - # reduced for 92% because on average a game is stopped earlier (LTC fishtest result). - scale = 2 * 0.92 - # estimated number of moves per game (LTC fishtest result) - game_moves = 68 - - chunks = tc.split('+') - increment = 0.0 - if len(chunks) == 2: - increment = float(chunks[1]) - - chunks = chunks[0].split('/') - num_moves = 0 - if len(chunks) == 2: - num_moves = int(chunks[0]) - - time_tc = chunks[-1] - chunks = time_tc.split(':') - if len(chunks) == 2: - time_tc = float(chunks[0]) * 60 + float(chunks[1]) - else: - time_tc = float(chunks[0]) + # Total time for a game is assumed to be the double of tc for each player + # reduced for 92% because on average a game is stopped earlier (LTC fishtest result). + scale = 2 * 0.92 + # estimated number of moves per game (LTC fishtest result) + game_moves = 68 + + chunks = tc.split("+") + increment = 0.0 + if len(chunks) == 2: + increment = float(chunks[1]) + + chunks = chunks[0].split("/") + num_moves = 0 + if len(chunks) == 2: + num_moves = int(chunks[0]) + + time_tc = chunks[-1] + chunks = time_tc.split(":") + if len(chunks) == 2: + time_tc = float(chunks[0]) * 60 + float(chunks[1]) + else: + time_tc = float(chunks[0]) - if num_moves > 0: - time_tc = time_tc * (game_moves / num_moves) + if num_moves > 0: + time_tc = time_tc * (game_moves / num_moves) - return (time_tc + (increment * game_moves)) * scale + return (time_tc + (increment * game_moves)) * scale def remaining_hours(run): - r = run['results'] - if 'sprt' in run['args']: - # current average number of games. Regularly update / have server guess? - expected_games = 53000 - # checking randomly, half the expected games needs still to be done - remaining_games = expected_games / 2 - else: - expected_games = run['args']['num_games'] - remaining_games = max(0, - expected_games - - r['wins'] - r['losses'] - r['draws']) - game_secs = estimate_game_duration(run['args']['tc']) - return game_secs * remaining_games * int( - run['args'].get('threads', 1)) / (60*60) + r = run["results"] + if "sprt" in run["args"]: + # current average number of games. Regularly update / have server guess? 
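# Worked example of estimate_game_duration() above (a sketch: the 2 * 0.92
# scale and the 68 moves/game are the LTC-derived constants documented in
# that function; the run is invented). For the common "10+0.1" STC:
#
#   estimate_game_duration("10+0.1")   # = (10.0 + 0.1 * 68) * 2 * 0.92
#                                      # = 30.912 seconds per game
#   30.912 * (53000 / 2) / 3600        # ~ 227.5 hours of 1-thread work for
#                                      # a fresh SPRT (the 53000-game guess
#                                      # used just below)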
+ expected_games = 53000 + # checking randomly, half the expected games needs still to be done + remaining_games = expected_games / 2 + else: + expected_games = run["args"]["num_games"] + remaining_games = max(0, expected_games - r["wins"] - r["losses"] - r["draws"]) + game_secs = estimate_game_duration(run["args"]["tc"]) + return game_secs * remaining_games * int(run["args"].get("threads", 1)) / (60 * 60) def post_in_fishcooking_results(run): - """ Posts the results of the run to the fishcooking forum: + """ Posts the results of the run to the fishcooking forum: https://groups.google.com/forum/?fromgroups=#!forum/fishcooking """ - title = run['args']['new_tag'][:23] + title = run["args"]["new_tag"][:23] - if 'username' in run['args']: - title += ' (' + run['args']['username'] + ')' + if "username" in run["args"]: + title += " (" + run["args"]["username"] + ")" - body = FISH_URL + '%s\n\n' % (str(run['_id'])) + body = FISH_URL + "%s\n\n" % (str(run["_id"])) - body += run['start_time'].strftime("%d-%m-%y") + ' from ' - body += run['args'].get('username', '') + '\n\n' + body += run["start_time"].strftime("%d-%m-%y") + " from " + body += run["args"].get("username", "") + "\n\n" - body += run['args']['new_tag'] + ': ' + run['args'].get( - 'msg_new', '') + '\n' - body += run['args']['base_tag'] + ': ' + run['args'].get( - 'msg_base', '') + '\n\n' + body += run["args"]["new_tag"] + ": " + run["args"].get("msg_new", "") + "\n" + body += run["args"]["base_tag"] + ": " + run["args"].get("msg_base", "") + "\n\n" - body += 'TC: ' + run['args']['tc'] + ' th ' + str( - run['args'].get('threads', 1)) + '\n' - body += '\n'.join(run['results_info']['info']) + '\n\n' + body += ( + "TC: " + run["args"]["tc"] + " th " + str(run["args"].get("threads", 1)) + "\n" + ) + body += "\n".join(run["results_info"]["info"]) + "\n\n" - body += run['args'].get('info', '') + '\n\n' + body += run["args"].get("info", "") + "\n\n" - msg = MIMEText(body) - msg['Subject'] = title - msg['From'] = 'fishtest@noreply.github.com' - msg['To'] = 'fishcooking_results@googlegroups.com' + msg = MIMEText(body) + msg["Subject"] = title + msg["From"] = "fishtest@noreply.github.com" + msg["To"] = "fishcooking_results@googlegroups.com" - try: - s = smtplib.SMTP('localhost') - s.sendmail(msg['From'], [msg['To']], msg.as_string()) - s.quit() - except ConnectionRefusedError: - print('Unable to post results to fishcooking forum') + try: + s = smtplib.SMTP("localhost") + s.sendmail(msg["From"], [msg["To"]], msg.as_string()) + s.quit() + except ConnectionRefusedError: + print("Unable to post results to fishcooking forum") def delta_date(date): - if date != datetime.min: - diff = datetime.utcnow() - date - if diff.days != 0: - delta = '%d days ago' % (diff.days) - elif diff.seconds / 3600 > 1: - delta = '%d hours ago' % (diff.seconds / 3600) - elif diff.seconds / 60 > 1: - delta = '%d minutes ago' % (diff.seconds / 60) + if date != datetime.min: + diff = datetime.utcnow() - date + if diff.days != 0: + delta = "%d days ago" % (diff.days) + elif diff.seconds / 3600 > 1: + delta = "%d hours ago" % (diff.seconds / 3600) + elif diff.seconds / 60 > 1: + delta = "%d minutes ago" % (diff.seconds / 60) + else: + delta = "seconds ago" else: - delta = 'seconds ago' - else: - delta = 'Never' - return delta + delta = "Never" + return delta diff --git a/fishtest/fishtest/views.py b/fishtest/fishtest/views.py index d93eae48f..414f65edb 100644 --- a/fishtest/fishtest/views.py +++ b/fishtest/fishtest/views.py @@ -14,1010 +14,1135 @@ from pyramid.response import 
Response import fishtest.stats.stat_util -from fishtest.util import calculate_residuals, format_results, estimate_game_duration, delta_date +from fishtest.util import ( + calculate_residuals, + format_results, + estimate_game_duration, + delta_date, +) def clear_cache(): - global last_time, last_tests - building.acquire() - last_time = 0 - last_tests = None - building.release() + global last_time, last_tests + building.acquire() + last_time = 0 + last_tests = None + building.release() def cached_flash(request, requestString): - clear_cache() - request.session.flash(requestString) - return + clear_cache() + request.session.flash(requestString) + return -@view_config(route_name='home') +@view_config(route_name="home") def home(request): - return HTTPFound(location=request.route_url('tests')) + return HTTPFound(location=request.route_url("tests")) -@view_config(route_name='login', renderer='login.mak', - require_csrf=True, request_method=('GET', 'POST')) -@forbidden_view_config(renderer='login.mak') +@view_config( + route_name="login", + renderer="login.mak", + require_csrf=True, + request_method=("GET", "POST"), +) +@forbidden_view_config(renderer="login.mak") def login(request): - login_url = request.route_url('login') - referrer = request.url - if referrer == login_url: - referrer = '/' # never use the login form itself as came_from - came_from = request.params.get('came_from', referrer) - - if request.method == 'POST': - username = request.POST.get('username') - password = request.POST.get('password') - token = request.userdb.authenticate(username, password) - if 'error' not in token: - if request.POST.get('stay_logged_in'): - # Session persists for a year after login - headers = remember(request, username, max_age=60 * 60 * 24 * 365) - else: - # Session ends when the browser is closed - headers = remember(request, username) - next_page = request.params.get('next') or came_from - return HTTPFound(location=next_page, headers=headers) - - request.session.flash(token['error'], 'error') # 'Incorrect password' - return {} + login_url = request.route_url("login") + referrer = request.url + if referrer == login_url: + referrer = "/" # never use the login form itself as came_from + came_from = request.params.get("came_from", referrer) + + if request.method == "POST": + username = request.POST.get("username") + password = request.POST.get("password") + token = request.userdb.authenticate(username, password) + if "error" not in token: + if request.POST.get("stay_logged_in"): + # Session persists for a year after login + headers = remember(request, username, max_age=60 * 60 * 24 * 365) + else: + # Session ends when the browser is closed + headers = remember(request, username) + next_page = request.params.get("next") or came_from + return HTTPFound(location=next_page, headers=headers) + + request.session.flash(token["error"], "error") # 'Incorrect password' + return {} # Guard against upload timeouts/retries uploading = threading.Semaphore() -@view_config(route_name='nn_upload', renderer='nn_upload.mak', - require_csrf=True) + +@view_config(route_name="nn_upload", renderer="nn_upload.mak", require_csrf=True) def upload(request): - if not uploading.acquire(False): - request.session.flash( - 'An other upload is in progress, please try again later', 'error') - return {} - result = sync_upload(request) - uploading.release() - return result + if not uploading.acquire(False): + request.session.flash( + "An other upload is in progress, please try again later", "error" + ) + return {} + result = 
sync_upload(request) + uploading.release() + return result -def sync_upload(request): - userid = authenticated_userid(request) - if not userid: - request.session.flash('Please login') - return HTTPFound(location=request.route_url('login')) - if request.method != 'POST': - return {} - try: - filename = request.POST['network'].filename - input_file = request.POST['network'].file - network = input_file.read() - errors = [] - if len(network) >= 100000000: - errors.append('Network must be < 100MB') - if not re.match(r"^nn-[0-9a-f]{12}\.nnue$", filename): - errors.append('Name must match "nn-[SHA256 first 12 digits].nnue"') - hash = hashlib.sha256(network).hexdigest() - if hash[:12] != filename[3:15]: - errors.append('Wrong SHA256 hash: ' + hash[:12] - + ' Filename: ' + filename[3:15]) - if request.rundb.get_nn(filename): - errors.append('Network already exists') - if errors: - for error in errors: - request.session.flash(error, 'error') - return {} - except: - request.session.flash('You must specify a network filename', 'error') - return {} - try: - with open(os.path.expanduser('~/fishtest.upload'), 'r') as f: - upload_server = f.read().replace('\n', '') - upload_server = upload_server + '/' + filename - response = requests.post(upload_server, data=network) - if response.status_code != 200: - print('Network upload failed: ' + str(response.status_code)) - request.session.flash('Network upload failed: ' - + str(response.status_code), 'error') +def sync_upload(request): + userid = authenticated_userid(request) + if not userid: + request.session.flash("Please login") + return HTTPFound(location=request.route_url("login")) + if request.method != "POST": + return {} + try: + filename = request.POST["network"].filename + input_file = request.POST["network"].file + network = input_file.read() + errors = [] + if len(network) >= 100000000: + errors.append("Network must be < 100MB") + if not re.match(r"^nn-[0-9a-f]{12}\.nnue$", filename): + errors.append('Name must match "nn-[SHA256 first 12 digits].nnue"') + hash = hashlib.sha256(network).hexdigest() + if hash[:12] != filename[3:15]: + errors.append( + "Wrong SHA256 hash: " + hash[:12] + " Filename: " + filename[3:15] + ) + if request.rundb.get_nn(filename): + errors.append("Network already exists") + if errors: + for error in errors: + request.session.flash(error, "error") + return {} + except: + request.session.flash("You must specify a network filename", "error") return {} - except Exception as e: - print("NN Upload fails or not configured: " + str(e)) - request.actiondb.upload_nn(authenticated_userid(request), filename) - request.rundb.upload_nn(userid, filename, network) + try: + with open(os.path.expanduser("~/fishtest.upload"), "r") as f: + upload_server = f.read().replace("\n", "") + upload_server = upload_server + "/" + filename + response = requests.post(upload_server, data=network) + if response.status_code != 200: + print("Network upload failed: " + str(response.status_code)) + request.session.flash( + "Network upload failed: " + str(response.status_code), "error" + ) + return {} + except Exception as e: + print("NN Upload fails or not configured: " + str(e)) - return HTTPFound(location=request.route_url('nns')) + request.actiondb.upload_nn(authenticated_userid(request), filename) + request.rundb.upload_nn(userid, filename, network) -@view_config(route_name='logout', require_csrf=True, request_method='POST') -def logout(request): - session = request.session - headers = forget(request) - session.invalidate() - return 
HTTPFound(location=request.route_url('tests'), headers=headers) + return HTTPFound(location=request.route_url("nns")) -@view_config(route_name='signup', renderer='signup.mak', - require_csrf=True, request_method=('GET', 'POST')) +@view_config(route_name="logout", require_csrf=True, request_method="POST") +def logout(request): + session = request.session + headers = forget(request) + session.invalidate() + return HTTPFound(location=request.route_url("tests"), headers=headers) + + +@view_config( + route_name="signup", + renderer="signup.mak", + require_csrf=True, + request_method=("GET", "POST"), +) def signup(request): - if request.method != 'POST': - return {} - errors = [] - if len(request.POST.get('password', '')) == 0: - errors.append('Non-empty password required') - if request.POST.get('password') != request.POST.get('password2', ''): - errors.append('Matching verify password required') - if '@' not in request.POST.get('email', ''): - errors.append('Email required') - if len(request.POST.get('username', '')) == 0: - errors.append('Username required') - if not request.POST.get('username', '').isalnum(): - errors.append('Alphanumeric username required') - if errors: - for error in errors: - request.session.flash(error, 'error') + if request.method != "POST": + return {} + errors = [] + if len(request.POST.get("password", "")) == 0: + errors.append("Non-empty password required") + if request.POST.get("password") != request.POST.get("password2", ""): + errors.append("Matching verify password required") + if "@" not in request.POST.get("email", ""): + errors.append("Email required") + if len(request.POST.get("username", "")) == 0: + errors.append("Username required") + if not request.POST.get("username", "").isalnum(): + errors.append("Alphanumeric username required") + if errors: + for error in errors: + request.session.flash(error, "error") + return {} + + path = os.path.expanduser("~/fishtest.captcha.secret") + if os.path.exists(path): + with open(path, "r") as f: + secret = f.read() + payload = { + "secret": secret, + "response": request.POST.get("g-recaptcha-response", ""), + "remoteip": request.remote_addr, + } + response = requests.post( + "https://www.google.com/recaptcha/api/siteverify", data=payload + ).json() + if "success" not in response or not response["success"]: + if "error-codes" in response: + print(response["error-codes"]) + request.session.flash("Captcha failed", "error") + return {} + + result = request.userdb.create_user( + username=request.POST.get("username", ""), + password=request.POST.get("password", ""), + email=request.POST.get("email", ""), + ) + if not result: + request.session.flash("Invalid username", "error") + else: + request.session.flash( + "Your account has been created, but will be activated by a human. This might take a few hours. Thank you for contributing!" 
+ ) + return HTTPFound(location=request.route_url("login")) return {} - path = os.path.expanduser('~/fishtest.captcha.secret') - if os.path.exists(path): - with open(path, 'r') as f: - secret = f.read() - payload = {'secret': secret, - 'response': request.POST.get('g-recaptcha-response', ''), - 'remoteip': request.remote_addr} - response = requests.post( - 'https://www.google.com/recaptcha/api/siteverify', - data=payload).json() - if 'success' not in response or not response['success']: - if 'error-codes' in response: - print(response['error-codes']) - request.session.flash('Captcha failed', 'error') - return {} - result = request.userdb.create_user( - username=request.POST.get('username', ''), - password=request.POST.get('password', ''), - email=request.POST.get('email', '') - ) - if not result: - request.session.flash('Invalid username', 'error') - else: - request.session.flash( - 'Your account has been created, but will be activated by a human. This might take a few hours. Thank you for contributing!') - return HTTPFound(location=request.route_url('login')) - return {} - - -@view_config(route_name='nns', renderer='nns.mak') +@view_config(route_name="nns", renderer="nns.mak") def nns(request): - nns_list = [] + nns_list = [] - for nn in request.rundb.get_nns(100): - nns_list.append(nn) - return {'nns': nns_list} + for nn in request.rundb.get_nns(100): + nns_list.append(nn) + return {"nns": nns_list} -@view_config(route_name='actions', renderer='actions.mak') +@view_config(route_name="actions", renderer="actions.mak") def actions(request): - search_action = request.params.get('action', '') - search_user = request.params.get('user', '') - - actions_list = [] - for action in request.actiondb.get_actions(100, search_action, search_user): - item = { - 'action': action['action'], - 'time': action['time'], - 'username': action['username'], + search_action = request.params.get("action", "") + search_user = request.params.get("user", "") + + actions_list = [] + for action in request.actiondb.get_actions(100, search_action, search_user): + item = { + "action": action["action"], + "time": action["time"], + "username": action["username"], + } + if action["action"] == "update_stats": + item["user"] = "" + item["description"] = "Update user statistics" + elif action["action"] == "upload_nn": + item["user"] = "" + item["description"] = "Upload " + action["data"] + elif action["action"] == "block_user": + item["description"] = ( + "blocked" if action["data"]["blocked"] else "unblocked" + ) + item["user"] = action["data"]["user"] + elif action["action"] == "modify_run": + item["run"] = action["data"]["before"]["args"]["new_tag"] + item["_id"] = action["data"]["before"]["_id"] + item["description"] = [] + + before = action["data"]["before"]["args"]["priority"] + after = action["data"]["after"]["args"]["priority"] + if before != after: + item["description"].append( + "priority changed from {} to {}".format(before, after) + ) + + before = action["data"]["before"]["args"]["num_games"] + after = action["data"]["after"]["args"]["num_games"] + if before != after: + item["description"].append( + "games changed from {} to {}".format(before, after) + ) + + before = action["data"]["before"]["args"]["throughput"] + after = action["data"]["after"]["args"]["throughput"] + if before != after: + item["description"].append( + "throughput changed from {} to {}".format(before, after) + ) + + before = action["data"]["before"]["args"]["auto_purge"] + after = action["data"]["after"]["args"]["auto_purge"] + if before != after: 
+ item["description"].append( + "auto-purge changed from {} to {}".format(before, after) + ) + + item["description"] = "modify: " + ", ".join(item["description"]) + else: + item["run"] = action["data"]["args"]["new_tag"] + item["_id"] = action["data"]["_id"] + item["description"] = " ".join(action["action"].split("_")) + if action["action"] == "stop_run": + item["description"] += ": {}".format( + action["data"].get("stop_reason", "User stop") + ) + + actions_list.append(item) + + return { + "actions": actions_list, + "approver": has_permission("approve_run", request.context, request), } - if action['action'] == 'update_stats': - item['user'] = '' - item['description'] = 'Update user statistics' - elif action['action'] == 'upload_nn': - item['user'] = '' - item['description'] = 'Upload ' + action['data'] - elif action['action'] == 'block_user': - item['description'] = ( - 'blocked' if action['data']['blocked'] else 'unblocked') - item['user'] = action['data']['user'] - elif action['action'] == 'modify_run': - item['run'] = action['data']['before']['args']['new_tag'] - item['_id'] = action['data']['before']['_id'] - item['description'] = [] - - before = action['data']['before']['args']['priority'] - after = action['data']['after']['args']['priority'] - if before != after: - item['description'].append( - 'priority changed from {} to {}'.format(before, after)) - - before = action['data']['before']['args']['num_games'] - after = action['data']['after']['args']['num_games'] - if before != after: - item['description'].append( - 'games changed from {} to {}'.format(before, after)) - - before = action['data']['before']['args']['throughput'] - after = action['data']['after']['args']['throughput'] - if before != after: - item['description'].append( - 'throughput changed from {} to {}'.format(before, after)) - - before = action['data']['before']['args']['auto_purge'] - after = action['data']['after']['args']['auto_purge'] - if before != after: - item['description'].append( - 'auto-purge changed from {} to {}'.format(before, after)) - - item['description'] = 'modify: ' + ', '.join(item['description']) - else: - item['run'] = action['data']['args']['new_tag'] - item['_id'] = action['data']['_id'] - item['description'] = ' '.join(action['action'].split('_')) - if action['action'] == 'stop_run': - item['description'] += ': {}'.format( - action['data'].get('stop_reason', 'User stop')) - - actions_list.append(item) - - return {'actions': actions_list, - 'approver': has_permission('approve_run', request.context, request)} def get_idle_users(request): - idle = {} - for u in request.userdb.get_users(): - idle[u['username']] = u - for u in request.userdb.user_cache.find(): - del idle[u['username']] - idle = list(idle.values()) - return idle + idle = {} + for u in request.userdb.get_users(): + idle[u["username"]] = u + for u in request.userdb.user_cache.find(): + del idle[u["username"]] + idle = list(idle.values()) + return idle -@view_config(route_name='pending', renderer='pending.mak') +@view_config(route_name="pending", renderer="pending.mak") def pending(request): - if not has_permission('approve_run', request.context, request): - request.session.flash('You cannot view pending users', 'error') - return HTTPFound(location=request.route_url('tests')) + if not has_permission("approve_run", request.context, request): + request.session.flash("You cannot view pending users", "error") + return HTTPFound(location=request.route_url("tests")) - return {'users': request.userdb.get_pending(), - 'idle': 
get_idle_users(request)} + return {"users": request.userdb.get_pending(), "idle": get_idle_users(request)} -@view_config(route_name='user', renderer='user.mak') -@view_config(route_name='profile', renderer='user.mak') +@view_config(route_name="user", renderer="user.mak") +@view_config(route_name="profile", renderer="user.mak") def user(request): - userid = authenticated_userid(request) - if not userid: - request.session.flash('Please login') - return HTTPFound(location=request.route_url('login')) - user_name = request.matchdict.get('username', userid) - profile = (user_name == userid) - if not profile and not has_permission( - 'approve_run', request.context, request): - request.session.flash('You cannot inspect users', 'error') - return HTTPFound(location=request.route_url('tests')) - user_data = request.userdb.get_user(user_name) - if 'user' in request.POST: - if profile: - if len(request.params.get('password')) > 0: - if (request.params.get('password') - != request.params.get('password2', '')): - request.session.flash('Matching verify password required', 'error') - return {'user': user_data, 'profile': profile} - user_data['password'] = request.params.get('password') - if len(request.params.get('email')) > 0: - user_data['email'] = request.params.get('email') - else: - user_data['blocked'] = ('blocked' in request.POST) - request.userdb.last_pending_time = 0 - request.actiondb.block_user(authenticated_userid(request), - {'user': user_name, 'blocked': user_data['blocked']}) - request.session.flash(('Blocked' if user_data['blocked'] else 'Unblocked') - + ' user ' + user_name) - request.userdb.save_user(user_data) - return HTTPFound(location=request.route_url('tests')) - userc = request.userdb.user_cache.find_one({'username': user_name}) - hours = int(userc['cpu_hours']) if userc is not None else 0 - return {'user': user_data, 'limit': request.userdb.get_machine_limit(user_name), - 'hours': hours, 'profile': profile} - - -@view_config(route_name='users', renderer='users.mak') + userid = authenticated_userid(request) + if not userid: + request.session.flash("Please login") + return HTTPFound(location=request.route_url("login")) + user_name = request.matchdict.get("username", userid) + profile = user_name == userid + if not profile and not has_permission("approve_run", request.context, request): + request.session.flash("You cannot inspect users", "error") + return HTTPFound(location=request.route_url("tests")) + user_data = request.userdb.get_user(user_name) + if "user" in request.POST: + if profile: + if len(request.params.get("password")) > 0: + if request.params.get("password") != request.params.get( + "password2", "" + ): + request.session.flash("Matching verify password required", "error") + return {"user": user_data, "profile": profile} + user_data["password"] = request.params.get("password") + if len(request.params.get("email")) > 0: + user_data["email"] = request.params.get("email") + else: + user_data["blocked"] = "blocked" in request.POST + request.userdb.last_pending_time = 0 + request.actiondb.block_user( + authenticated_userid(request), + {"user": user_name, "blocked": user_data["blocked"]}, + ) + request.session.flash( + ("Blocked" if user_data["blocked"] else "Unblocked") + + " user " + + user_name + ) + request.userdb.save_user(user_data) + return HTTPFound(location=request.route_url("tests")) + userc = request.userdb.user_cache.find_one({"username": user_name}) + hours = int(userc["cpu_hours"]) if userc is not None else 0 + return { + "user": user_data, + "limit": 
request.userdb.get_machine_limit(user_name), + "hours": hours, + "profile": profile, + } + + +@view_config(route_name="users", renderer="users.mak") def users(request): - users_list = list(request.userdb.user_cache.find()) - users_list.sort(key=lambda k: k['cpu_hours'], reverse=True) - return {'users': users_list} + users_list = list(request.userdb.user_cache.find()) + users_list.sort(key=lambda k: k["cpu_hours"], reverse=True) + return {"users": users_list} -@view_config(route_name='users_monthly', renderer='users.mak') +@view_config(route_name="users_monthly", renderer="users.mak") def users_monthly(request): - users_list = list(request.userdb.top_month.find()) - users_list.sort(key=lambda k: k['cpu_hours'], reverse=True) - return {'users': users_list} + users_list = list(request.userdb.top_month.find()) + users_list.sort(key=lambda k: k["cpu_hours"], reverse=True) + return {"users": users_list} def get_master_bench(): - bs = re.compile(r"(^|\s)[Bb]ench[ :]+([0-9]+)", re.MULTILINE) - for c in requests.get( - 'https://api.github.com/repos/official-stockfish/Stockfish/commits').json(): - if not 'commit' in c: - return None - m = bs.search(c['commit']['message']) - if m: - return m.group(2) - return None + bs = re.compile(r"(^|\s)[Bb]ench[ :]+([0-9]+)", re.MULTILINE) + for c in requests.get( + "https://api.github.com/repos/official-stockfish/Stockfish/commits" + ).json(): + if not "commit" in c: + return None + m = bs.search(c["commit"]["message"]) + if m: + return m.group(2) + return None def get_sha(branch, repo_url): - """ Resolves the git branch to sha commit """ - api_url = repo_url.replace('https://github.com', - 'https://api.github.com/repos') - try: - commit = requests.get(api_url + '/commits/' + branch).json() - except: - raise Exception("Unable to access developer repository") - if 'sha' in commit: - return commit['sha'], commit['commit']['message'].split('\n')[0] - else: - return '', '' + """ Resolves the git branch to sha commit """ + api_url = repo_url.replace("https://github.com", "https://api.github.com/repos") + try: + commit = requests.get(api_url + "/commits/" + branch).json() + except: + raise Exception("Unable to access developer repository") + if "sha" in commit: + return commit["sha"], commit["commit"]["message"].split("\n")[0] + else: + return "", "" def get_net(branch, repo_url): - """ Get the net from ucioption.cpp in the repo """ - api_url = repo_url.replace('https://github.com', - 'https://raw.githubusercontent.com') - try: - api_url = api_url + '/' + branch + '/src/ucioption.cpp' - options = requests.get(api_url).content.decode('utf-8') - net = None - for line in options.splitlines(): - if 'EvalFile' in line and 'Option' in line: - p = re.compile('nn-[a-z0-9]{12}.nnue') - m = p.search(line) - if m: - net = m.group(0) - return net - except: - raise Exception("Unable to access developer repository: " + api_url) + """ Get the net from ucioption.cpp in the repo """ + api_url = repo_url.replace( + "https://github.com", "https://raw.githubusercontent.com" + ) + try: + api_url = api_url + "/" + branch + "/src/ucioption.cpp" + options = requests.get(api_url).content.decode("utf-8") + net = None + for line in options.splitlines(): + if "EvalFile" in line and "Option" in line: + p = re.compile("nn-[a-z0-9]{12}.nnue") + m = p.search(line) + if m: + net = m.group(0) + return net + except: + raise Exception("Unable to access developer repository: " + api_url) + def parse_spsa_params(raw, spsa): - params = [] - for line in raw.split('\n'): - chunks = line.strip().split(',') - 
if len(chunks)==1 and chunks[0]=="": # blank line - continue - if len(chunks) != 6: - raise Exception('the line %s does not have 6 entries' % (chunks)) - param = { - 'name': chunks[0], - 'start': float(chunks[1]), - 'min': float(chunks[2]), - 'max': float(chunks[3]), - 'c_end': float(chunks[4]), - 'r_end': float(chunks[5]), - } - param['c'] = param['c_end'] * spsa['num_iter'] ** spsa['gamma'] - param['a_end'] = param['r_end'] * param['c_end'] ** 2 - param['a'] = param['a_end'] * (spsa['A'] + spsa['num_iter']) ** spsa['alpha'] - param['theta'] = param['start'] - params.append(param) - return params + params = [] + for line in raw.split("\n"): + chunks = line.strip().split(",") + if len(chunks) == 1 and chunks[0] == "": # blank line + continue + if len(chunks) != 6: + raise Exception("the line %s does not have 6 entries" % (chunks)) + param = { + "name": chunks[0], + "start": float(chunks[1]), + "min": float(chunks[2]), + "max": float(chunks[3]), + "c_end": float(chunks[4]), + "r_end": float(chunks[5]), + } + param["c"] = param["c_end"] * spsa["num_iter"] ** spsa["gamma"] + param["a_end"] = param["r_end"] * param["c_end"] ** 2 + param["a"] = param["a_end"] * (spsa["A"] + spsa["num_iter"]) ** spsa["alpha"] + param["theta"] = param["start"] + params.append(param) + return params def validate_form(request): - data = { - 'base_tag': request.POST['base-branch'], - 'new_tag': request.POST['test-branch'], - 'tc': request.POST['tc'], - 'book': request.POST['book'], - 'book_depth': request.POST['book-depth'], - 'base_signature': request.POST['base-signature'], - 'new_signature': request.POST['test-signature'], - 'base_options': request.POST['base-options'], - 'new_options': request.POST['new-options'], - 'username': authenticated_userid(request), - 'tests_repo': request.POST['tests-repo'], - 'info': request.POST['run-info'], - } - - if not re.match(r"^([1-9]\d*/)?\d+(\.\d+)?(\+\d+(\.\d+)?)?$", data['tc']): - raise Exception('Bad time control format') - - if request.POST.get('rescheduled_from'): - data['rescheduled_from'] = request.POST['rescheduled_from'] - - def strip_message(m): - s = re.sub(r"[Bb]ench[ :]+[0-9]+\s*", "", m) - s = re.sub(r"[ \t]+", " ", s) - s = re.sub(r"\n+", r"\n", s) - return s.rstrip() - - # Fill new_signature/info from commit info if left blank - if len(data['new_signature']) == 0 or len(data['info']) == 0: - api_url = data['tests_repo'].replace('https://github.com', - 'https://api.github.com/repos') - api_url += ('/commits' + '/' + data['new_tag']) - try: - c = requests.get(api_url).json() - except: - raise Exception("Unable to access developer repository") - if 'commit' not in c: - raise Exception('Cannot find branch in developer repository') - if len(data['new_signature']) == 0: - bs = re.compile(r"(^|\s)[Bb]ench[ :]+([0-9]+)", re.MULTILINE) - m = bs.search(c['commit']['message']) - if m: - data['new_signature'] = m.group(2) - else: - raise Exception("This commit has no signature: please supply it manually.") - if len(data['info']) == 0: - data['info'] = ('' if re.match(r"^[012]?[0-9][^0-9].*", data['tc']) - else 'LTC: ') + strip_message(c['commit']['message']) - - # Check that the book exists in the official books repo - if len(data['book']) > 0: - api_url = 'https://api.github.com/repos/official-stockfish/books/contents' - c = requests.get(api_url).json() - matcher = re.compile(r"\.(epd|pgn)\.zip$") - valid_book_filenames = [file['name'] for file in c if matcher.search(file['name'])] - if data['book'] + '.zip' not in valid_book_filenames: - raise Exception('Invalid book - 
' + data['book']) - - if request.POST['stop_rule']=='spsa': - data['base_signature']=data['new_signature'] - - for k,v in data.items(): - if len(v)==0: - raise Exception('Missing required option: %s' % k) - - data['auto_purge'] = request.POST.get('auto-purge') is not None - - # In case of reschedule use old data, - # otherwise resolve sha and update user's tests_repo - if 'resolved_base' in request.POST: - data['resolved_base'] = request.POST['resolved_base'] - data['resolved_new'] = request.POST['resolved_new'] - data['msg_base'] = request.POST['msg_base'] - data['msg_new'] = request.POST['msg_new'] - else: - data['resolved_base'], data['msg_base'] = get_sha( - data['base_tag'], data['tests_repo']) - data['resolved_new'], data['msg_new'] = get_sha( - data['new_tag'], data['tests_repo']) - u = request.userdb.get_user(data['username']) - if u.get('tests_repo', '') != data['tests_repo']: - u['tests_repo'] = data['tests_repo'] - request.userdb.users.save(u) - - if len(data['resolved_base']) == 0 or len(data['resolved_new']) == 0: - raise Exception('Unable to find branch!') - - # Check entered bench - if data['base_tag'] == 'master': - found = False - api_url = data['tests_repo'].replace('https://github.com', - 'https://api.github.com/repos') - api_url += '/commits' - bs = re.compile(r"(^|\s)[Bb]ench[ :]+([0-9]+)", re.MULTILINE) - for c in requests.get(api_url).json(): - m = bs.search(c['commit']['message']) - if m: - found = True - break - if not found or m.group(2) != data['base_signature']: - raise Exception('Bench signature of Base master does not match, ' - + 'please "git pull upstream master" !') - - stop_rule = request.POST['stop_rule'] - - # Check if the base branch of the test repo matches official master - api_url = 'https://api.github.com/repos/official-stockfish/Stockfish' - api_url += '/compare/master...' + data['resolved_base'][:10] - master_diff = requests.get(api_url, headers={ - 'Accept': 'application/vnd.github.v3.diff' - }) - data['base_same_as_master'] = master_diff.text is '' - - # Test existence of net - new_net = get_net(data['new_tag'], data['tests_repo']) - if new_net: - if not request.rundb.get_nn(new_net): - raise Exception("Net not in repository: " + new_net) - - # Store net info - data['new_net'] = new_net - data['base_net'] = get_net(data['base_tag'], data['tests_repo']) - - # Integer parameters - - if stop_rule == 'sprt': - sprt_batch_size_games=8 - assert(sprt_batch_size_games%2==0) - assert(request.rundb.chunk_size%sprt_batch_size_games==0) - data['sprt'] = fishtest.stats.stat_util.SPRT(alpha=0.05, - beta=0.05, - elo0=float(request.POST['sprt_elo0']), - elo1=float(request.POST['sprt_elo1']), - elo_model='logistic', - batch_size=sprt_batch_size_games//2) #game pairs - # Limit on number of games played. - # Shouldn't be hit in practice as long as it is larger than > ~200000 - # must scale with chunk_size to avoid overloading the server. 
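# Aside on the base_same_as_master assignment earlier in validate_form:
# `master_diff.text is ''` is an identity test, not a value test. It works
# in practice only because CPython happens to hand back its cached
# empty-string object here, and Python >= 3.8 flags the pattern
# ("SyntaxWarning: 'is' with a literal"). A minimal sketch of the comparison
# as presumably intended (hypothetical helper; same endpoint and header):

import requests

def base_is_official_master(sha):
    diff = requests.get(
        "https://api.github.com/repos/official-stockfish/Stockfish"
        "/compare/master..." + sha[:10],
        headers={"Accept": "application/vnd.github.v3.diff"},
    )
    return diff.text == ""  # an empty diff body <=> branch matches master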
- data['num_games'] = 2000 * request.rundb.chunk_size - elif stop_rule == 'spsa': - data['num_games'] = int(request.POST['num-games']) - if data['num_games'] <= 0: - raise Exception('Number of games must be >= 0') - - data['spsa'] = { - 'A': int(request.POST['spsa_A']), - 'alpha': float(request.POST['spsa_alpha']), - 'gamma': float(request.POST['spsa_gamma']), - 'raw_params': request.POST['spsa_raw_params'], - 'iter': 0, - 'num_iter': int(data['num_games'] / 2), - 'clipping': request.POST['spsa_clipping'], - 'rounding': request.POST['spsa_rounding'], + data = { + "base_tag": request.POST["base-branch"], + "new_tag": request.POST["test-branch"], + "tc": request.POST["tc"], + "book": request.POST["book"], + "book_depth": request.POST["book-depth"], + "base_signature": request.POST["base-signature"], + "new_signature": request.POST["test-signature"], + "base_options": request.POST["base-options"], + "new_options": request.POST["new-options"], + "username": authenticated_userid(request), + "tests_repo": request.POST["tests-repo"], + "info": request.POST["run-info"], } - data['spsa']['params'] = parse_spsa_params( - request.POST['spsa_raw_params'], data['spsa']) - if len(data['spsa']['params']) == 0: - raise Exception('Number of params must be > 0') - else: - data['num_games'] = int(request.POST['num-games']) - if data['num_games'] <= 0: - raise Exception('Number of games must be >= 0') - max_games = 4000 * request.rundb.chunk_size - if data['num_games'] > max_games: - raise Exception('Number of games must be <= ' + str(max_games)) + if not re.match(r"^([1-9]\d*/)?\d+(\.\d+)?(\+\d+(\.\d+)?)?$", data["tc"]): + raise Exception("Bad time control format") + + if request.POST.get("rescheduled_from"): + data["rescheduled_from"] = request.POST["rescheduled_from"] + + def strip_message(m): + s = re.sub(r"[Bb]ench[ :]+[0-9]+\s*", "", m) + s = re.sub(r"[ \t]+", " ", s) + s = re.sub(r"\n+", r"\n", s) + return s.rstrip() + + # Fill new_signature/info from commit info if left blank + if len(data["new_signature"]) == 0 or len(data["info"]) == 0: + api_url = data["tests_repo"].replace( + "https://github.com", "https://api.github.com/repos" + ) + api_url += "/commits" + "/" + data["new_tag"] + try: + c = requests.get(api_url).json() + except: + raise Exception("Unable to access developer repository") + if "commit" not in c: + raise Exception("Cannot find branch in developer repository") + if len(data["new_signature"]) == 0: + bs = re.compile(r"(^|\s)[Bb]ench[ :]+([0-9]+)", re.MULTILINE) + m = bs.search(c["commit"]["message"]) + if m: + data["new_signature"] = m.group(2) + else: + raise Exception( + "This commit has no signature: please supply it manually." 
+ ) + if len(data["info"]) == 0: + data["info"] = ( + "" if re.match(r"^[012]?[0-9][^0-9].*", data["tc"]) else "LTC: " + ) + strip_message(c["commit"]["message"]) + + # Check that the book exists in the official books repo + if len(data["book"]) > 0: + api_url = "https://api.github.com/repos/official-stockfish/books/contents" + c = requests.get(api_url).json() + matcher = re.compile(r"\.(epd|pgn)\.zip$") + valid_book_filenames = [ + file["name"] for file in c if matcher.search(file["name"]) + ] + if data["book"] + ".zip" not in valid_book_filenames: + raise Exception("Invalid book - " + data["book"]) + + if request.POST["stop_rule"] == "spsa": + data["base_signature"] = data["new_signature"] + + for k, v in data.items(): + if len(v) == 0: + raise Exception("Missing required option: %s" % k) + + data["auto_purge"] = request.POST.get("auto-purge") is not None + + # In case of reschedule use old data, + # otherwise resolve sha and update user's tests_repo + if "resolved_base" in request.POST: + data["resolved_base"] = request.POST["resolved_base"] + data["resolved_new"] = request.POST["resolved_new"] + data["msg_base"] = request.POST["msg_base"] + data["msg_new"] = request.POST["msg_new"] + else: + data["resolved_base"], data["msg_base"] = get_sha( + data["base_tag"], data["tests_repo"] + ) + data["resolved_new"], data["msg_new"] = get_sha( + data["new_tag"], data["tests_repo"] + ) + u = request.userdb.get_user(data["username"]) + if u.get("tests_repo", "") != data["tests_repo"]: + u["tests_repo"] = data["tests_repo"] + request.userdb.users.save(u) + + if len(data["resolved_base"]) == 0 or len(data["resolved_new"]) == 0: + raise Exception("Unable to find branch!") + + # Check entered bench + if data["base_tag"] == "master": + found = False + api_url = data["tests_repo"].replace( + "https://github.com", "https://api.github.com/repos" + ) + api_url += "/commits" + bs = re.compile(r"(^|\s)[Bb]ench[ :]+([0-9]+)", re.MULTILINE) + for c in requests.get(api_url).json(): + m = bs.search(c["commit"]["message"]) + if m: + found = True + break + if not found or m.group(2) != data["base_signature"]: + raise Exception( + "Bench signature of Base master does not match, " + + 'please "git pull upstream master" !' + ) + + stop_rule = request.POST["stop_rule"] + + # Check if the base branch of the test repo matches official master + api_url = "https://api.github.com/repos/official-stockfish/Stockfish" + api_url += "/compare/master..." + data["resolved_base"][:10] + master_diff = requests.get( + api_url, headers={"Accept": "application/vnd.github.v3.diff"} + ) + data["base_same_as_master"] = master_diff.text is "" + + # Test existence of net + new_net = get_net(data["new_tag"], data["tests_repo"]) + if new_net: + if not request.rundb.get_nn(new_net): + raise Exception("Net not in repository: " + new_net) + + # Store net info + data["new_net"] = new_net + data["base_net"] = get_net(data["base_tag"], data["tests_repo"]) + + # Integer parameters + + if stop_rule == "sprt": + sprt_batch_size_games = 8 + assert sprt_batch_size_games % 2 == 0 + assert request.rundb.chunk_size % sprt_batch_size_games == 0 + data["sprt"] = fishtest.stats.stat_util.SPRT( + alpha=0.05, + beta=0.05, + elo0=float(request.POST["sprt_elo0"]), + elo1=float(request.POST["sprt_elo1"]), + elo_model="logistic", + batch_size=sprt_batch_size_games // 2, + ) # game pairs + # Limit on number of games played. + # Shouldn't be hit in practice as long as it is larger than > ~200000 + # must scale with chunk_size to avoid overloading the server. 
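# Worked numbers for the asserts and the cap around this spot (a sketch:
# chunk_size lives on rundb and 1000 is an assumed value, purely for
# illustration):
#   sprt_batch_size_games = 8  ->  batch_size = 8 // 2 = 4 game pairs, so
#       the pentanomial SPRT statistics always see whole pairs;
#   1000 % 8 == 0  ->  every server chunk holds an exact number of batches;
#   num_games = 2000 * 1000 = 2_000_000, a ceiling well above the ~200000
#       games a practical SPRT needs, while scaling with chunk_size keeps
#       the task count per run bounded.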
+ data["num_games"] = 2000 * request.rundb.chunk_size + elif stop_rule == "spsa": + data["num_games"] = int(request.POST["num-games"]) + if data["num_games"] <= 0: + raise Exception("Number of games must be >= 0") + + data["spsa"] = { + "A": int(request.POST["spsa_A"]), + "alpha": float(request.POST["spsa_alpha"]), + "gamma": float(request.POST["spsa_gamma"]), + "raw_params": request.POST["spsa_raw_params"], + "iter": 0, + "num_iter": int(data["num_games"] / 2), + "clipping": request.POST["spsa_clipping"], + "rounding": request.POST["spsa_rounding"], + } + data["spsa"]["params"] = parse_spsa_params( + request.POST["spsa_raw_params"], data["spsa"] + ) + if len(data["spsa"]["params"]) == 0: + raise Exception("Number of params must be > 0") + else: + data["num_games"] = int(request.POST["num-games"]) + if data["num_games"] <= 0: + raise Exception("Number of games must be >= 0") + + max_games = 4000 * request.rundb.chunk_size + if data["num_games"] > max_games: + raise Exception("Number of games must be <= " + str(max_games)) - data['threads'] = int(request.POST['threads']) - data['priority'] = int(request.POST['priority']) - data['throughput'] = int(request.POST['throughput']) + data["threads"] = int(request.POST["threads"]) + data["priority"] = int(request.POST["priority"]) + data["throughput"] = int(request.POST["throughput"]) - if data['threads'] <= 0: - raise Exception('Threads must be >= 1') + if data["threads"] <= 0: + raise Exception("Threads must be >= 1") - return data + return data def del_tasks(run): - if 'tasks' in run: - run = copy.deepcopy(run) - del run['tasks'] - return run + if "tasks" in run: + run = copy.deepcopy(run) + del run["tasks"] + return run def update_nets(request, run): - run_id = str(run['_id']) - data = run['args'] - if run['base_same_as_master']: - base_net = data['base_net'] - if base_net: - net = request.rundb.get_nn(base_net) - if not net: - # Should never happen: - raise Exception("Net not in repository: " + base_net) - if 'is_master' not in net: - net['is_master'] = True + run_id = str(run["_id"]) + data = run["args"] + if run["base_same_as_master"]: + base_net = data["base_net"] + if base_net: + net = request.rundb.get_nn(base_net) + if not net: + # Should never happen: + raise Exception("Net not in repository: " + base_net) + if "is_master" not in net: + net["is_master"] = True + request.rundb.update_nn(net) + new_net = data["new_net"] + if new_net: + net = request.rundb.get_nn(new_net) + if not net: + return + if "first_test" not in net: + net["first_test"] = {"id": run_id, "date": datetime.datetime.utcnow()} + net["last_test"] = {"id": run_id, "date": datetime.datetime.utcnow()} request.rundb.update_nn(net) - new_net = data['new_net'] - if new_net: - net = request.rundb.get_nn(new_net) - if not net: - return - if 'first_test' not in net: - net['first_test'] = { 'id': run_id, 'date': datetime.datetime.utcnow() } - net['last_test'] = { 'id': run_id, 'date': datetime.datetime.utcnow() } - request.rundb.update_nn(net) - - -@view_config(route_name='tests_run', renderer='tests_run.mak', require_csrf=True) -def tests_run(request): - if not authenticated_userid(request): - request.session.flash('Please login') - next_page = '/tests/run' - if 'id' in request.params: - next_page += '?id={}'.format(request.params['id']) - return HTTPFound(location='{}?next={}'.format(request.route_url('login'), next_page)) - if request.method == 'POST': - try: - data = validate_form(request) - run_id = request.rundb.new_run(**data) - run = del_tasks(request.rundb.get_run(run_id)) 
- request.actiondb.new_run(authenticated_userid(request), run) - cached_flash(request, 'Submitted test to the queue!') - return HTTPFound(location='/tests/view/' + str(run_id)) - except Exception as e: - request.session.flash(str(e), 'error') - - run_args = {} - if 'id' in request.params: - run_args = request.rundb.get_run(request.params['id'])['args'] - username = authenticated_userid(request) - u = request.userdb.get_user(username) - return {'args': run_args, - 'is_rerun': len(run_args) > 0, - 'rescheduled_from': request.params['id'] if 'id' in request.params else None, - 'tests_repo': u.get('tests_repo', ''), - 'bench': get_master_bench()} +@view_config(route_name="tests_run", renderer="tests_run.mak", require_csrf=True) +def tests_run(request): + if not authenticated_userid(request): + request.session.flash("Please login") + next_page = "/tests/run" + if "id" in request.params: + next_page += "?id={}".format(request.params["id"]) + return HTTPFound( + location="{}?next={}".format(request.route_url("login"), next_page) + ) + if request.method == "POST": + try: + data = validate_form(request) + run_id = request.rundb.new_run(**data) + run = del_tasks(request.rundb.get_run(run_id)) + request.actiondb.new_run(authenticated_userid(request), run) + cached_flash(request, "Submitted test to the queue!") + return HTTPFound(location="/tests/view/" + str(run_id)) + except Exception as e: + request.session.flash(str(e), "error") + + run_args = {} + if "id" in request.params: + run_args = request.rundb.get_run(request.params["id"])["args"] + + username = authenticated_userid(request) + u = request.userdb.get_user(username) + + return { + "args": run_args, + "is_rerun": len(run_args) > 0, + "rescheduled_from": request.params["id"] if "id" in request.params else None, + "tests_repo": u.get("tests_repo", ""), + "bench": get_master_bench(), + } def can_modify_run(request, run): - return (run['args']['username'] == authenticated_userid(request) - or has_permission('approve_run', request.context, request)) + return run["args"]["username"] == authenticated_userid(request) or has_permission( + "approve_run", request.context, request + ) -@view_config(route_name='tests_modify', require_csrf=True, request_method='POST') +@view_config(route_name="tests_modify", require_csrf=True, request_method="POST") def tests_modify(request): - if not authenticated_userid(request): - request.session.flash('Please login') - return HTTPFound(location=request.route_url('login')) - if 'num-games' in request.POST: - run = request.rundb.get_run(request.POST['run']) - before = del_tasks(run) - - if not can_modify_run(request, run): - request.session.flash("Unable to modify another user's run!", 'error') - return HTTPFound(location=request.route_url('tests')) - - existing_games = 0 - for chunk in run['tasks']: - existing_games += chunk['num_games'] - if 'stats' in chunk: - stats = chunk['stats'] - total = stats['wins'] + stats['losses'] + stats['draws'] - if total < chunk['num_games']: - chunk['pending'] = True - - num_games = int(request.POST['num-games']) - if (num_games > run['args']['num_games'] - and 'sprt' not in run['args'] - and 'spsa' not in run['args']): - request.session.flash( - 'Unable to modify number of games in a fixed game test!', 'error') - return HTTPFound(location=request.route_url('tests')) - - max_games = 4000 * request.rundb.chunk_size - if num_games > max_games: - request.session.flash('Number of games must be <= ' + str(max_games), 'error') - return HTTPFound(location=request.route_url('tests')) - - if 
num_games > existing_games: - # Create new chunks for the games - new_chunks = request.rundb.generate_tasks(num_games - existing_games) - run['tasks'] += new_chunks - - run['finished'] = False - run['args']['num_games'] = num_games - run['args']['priority'] = int(request.POST['priority']) - run['args']['throughput'] = int(request.POST['throughput']) - run['args']['auto_purge'] = True if request.POST.get('auto_purge') else False - request.rundb.calc_itp(run) - request.rundb.buffer(run, True) - request.rundb.task_time = 0 - - after = del_tasks(run) - request.actiondb.modify_run(authenticated_userid(request), before, after) - - cached_flash(request, 'Run successfully modified!') - return HTTPFound(location=request.route_url('tests')) - - -@view_config(route_name='tests_stop', require_csrf=True, request_method='POST') + if not authenticated_userid(request): + request.session.flash("Please login") + return HTTPFound(location=request.route_url("login")) + if "num-games" in request.POST: + run = request.rundb.get_run(request.POST["run"]) + before = del_tasks(run) + + if not can_modify_run(request, run): + request.session.flash("Unable to modify another user's run!", "error") + return HTTPFound(location=request.route_url("tests")) + + existing_games = 0 + for chunk in run["tasks"]: + existing_games += chunk["num_games"] + if "stats" in chunk: + stats = chunk["stats"] + total = stats["wins"] + stats["losses"] + stats["draws"] + if total < chunk["num_games"]: + chunk["pending"] = True + + num_games = int(request.POST["num-games"]) + if ( + num_games > run["args"]["num_games"] + and "sprt" not in run["args"] + and "spsa" not in run["args"] + ): + request.session.flash( + "Unable to modify number of games in a fixed game test!", "error" + ) + return HTTPFound(location=request.route_url("tests")) + + max_games = 4000 * request.rundb.chunk_size + if num_games > max_games: + request.session.flash( + "Number of games must be <= " + str(max_games), "error" + ) + return HTTPFound(location=request.route_url("tests")) + + if num_games > existing_games: + # Create new chunks for the games + new_chunks = request.rundb.generate_tasks(num_games - existing_games) + run["tasks"] += new_chunks + + run["finished"] = False + run["args"]["num_games"] = num_games + run["args"]["priority"] = int(request.POST["priority"]) + run["args"]["throughput"] = int(request.POST["throughput"]) + run["args"]["auto_purge"] = True if request.POST.get("auto_purge") else False + request.rundb.calc_itp(run) + request.rundb.buffer(run, True) + request.rundb.task_time = 0 + + after = del_tasks(run) + request.actiondb.modify_run(authenticated_userid(request), before, after) + + cached_flash(request, "Run successfully modified!") + return HTTPFound(location=request.route_url("tests")) + + +@view_config(route_name="tests_stop", require_csrf=True, request_method="POST") def tests_stop(request): - if not authenticated_userid(request): - request.session.flash('Please login') - return HTTPFound(location=request.route_url('login')) - if 'run-id' in request.POST: - run = request.rundb.get_run(request.POST['run-id']) - if not can_modify_run(request, run): - request.session.flash('Unable to modify another users run!', 'error') - return HTTPFound(location=request.route_url('tests')) - - run['finished'] = True - request.rundb.stop_run(request.POST['run-id']) - run = del_tasks(run) - request.actiondb.stop_run(authenticated_userid(request), run) - cached_flash(request, 'Stopped run') - return HTTPFound(location=request.route_url('tests')) - - 
-@view_config(route_name='tests_approve',
-             require_csrf=True, request_method='POST')
+    if not authenticated_userid(request):
+        request.session.flash("Please login")
+        return HTTPFound(location=request.route_url("login"))
+    if "run-id" in request.POST:
+        run = request.rundb.get_run(request.POST["run-id"])
+        if not can_modify_run(request, run):
+            request.session.flash("Unable to modify another user's run!", "error")
+            return HTTPFound(location=request.route_url("tests"))
+
+        run["finished"] = True
+        request.rundb.stop_run(request.POST["run-id"])
+        run = del_tasks(run)
+        request.actiondb.stop_run(authenticated_userid(request), run)
+        cached_flash(request, "Stopped run")
+    return HTTPFound(location=request.route_url("tests"))
+
+
+@view_config(route_name="tests_approve", require_csrf=True, request_method="POST")
 def tests_approve(request):
-  if not authenticated_userid(request):
-    request.session.flash('Please login')
-    return HTTPFound(location=request.route_url('login'))
-  if not has_permission('approve_run', request.context, request):
-    request.session.flash('Please login as approver')
-    return HTTPFound(location=request.route_url('login'))
-  username = authenticated_userid(request)
-  run_id = request.POST['run-id']
-  if request.rundb.approve_run(run_id, username):
-    run = request.rundb.get_run(run_id)
-    run = del_tasks(run)
-    update_nets(request, run)
-    request.actiondb.approve_run(username, run)
-    cached_flash(request, 'Approved run')
-  else:
-    request.session.flash('Unable to approve run!', 'error')
-  return HTTPFound(location=request.route_url('tests'))
+    if not authenticated_userid(request):
+        request.session.flash("Please login")
+        return HTTPFound(location=request.route_url("login"))
+    if not has_permission("approve_run", request.context, request):
+        request.session.flash("Please login as approver")
+        return HTTPFound(location=request.route_url("login"))
+    username = authenticated_userid(request)
+    run_id = request.POST["run-id"]
+    if request.rundb.approve_run(run_id, username):
+        run = request.rundb.get_run(run_id)
+        run = del_tasks(run)
+        update_nets(request, run)
+        request.actiondb.approve_run(username, run)
+        cached_flash(request, "Approved run")
+    else:
+        request.session.flash("Unable to approve run!", "error")
+    return HTTPFound(location=request.route_url("tests"))
 
 
-@view_config(route_name='tests_purge', require_csrf=True, request_method='POST')
+@view_config(route_name="tests_purge", require_csrf=True, request_method="POST")
 def tests_purge(request):
-  if not has_permission('approve_run', request.context, request):
-    request.session.flash('Please login as approver')
-    return HTTPFound(location=request.route_url('login'))
-  username = authenticated_userid(request)
+    if not has_permission("approve_run", request.context, request):
+        request.session.flash("Please login as approver")
+        return HTTPFound(location=request.route_url("login"))
+    username = authenticated_userid(request)
 
-  run = request.rundb.get_run(request.POST['run-id'])
-  if not run['finished']:
-    request.session.flash('Can only purge completed run', 'error')
-    return HTTPFound(location=request.route_url('tests'))
+    run = request.rundb.get_run(request.POST["run-id"])
+    if not run["finished"]:
+        request.session.flash("Can only purge completed run", "error")
+        return HTTPFound(location=request.route_url("tests"))
 
-  purged = request.rundb.purge_run(run)
-  if not purged:
-    request.session.flash('No bad workers!')
-    return HTTPFound(location=request.route_url('tests'))
-
-  run = del_tasks(run)
-  request.actiondb.purge_run(username, run)
-
-  cached_flash(request, 'Purged run')
-  return HTTPFound(location=request.route_url('tests'))
-
-
-@view_config(route_name='tests_delete', require_csrf=True, request_method='POST')
-def tests_delete(request):
-  if not authenticated_userid(request):
-    request.session.flash('Please login')
-    return HTTPFound(location=request.route_url('login'))
-  if 'run-id' in request.POST:
-    run = request.rundb.get_run(request.POST['run-id'])
-    if not can_modify_run(request, run):
-      request.session.flash('Unable to modify another users run!', 'error')
-      return HTTPFound(location=request.route_url('tests'))
-
-    run['deleted'] = True
-    run['finished'] = True
-    for w in run['tasks']:
-      w['pending'] = False
-    request.rundb.buffer(run, True)
-    request.rundb.task_time = 0
+    purged = request.rundb.purge_run(run)
+    if not purged:
+        request.session.flash("No bad workers!")
+        return HTTPFound(location=request.route_url("tests"))
 
     run = del_tasks(run)
-  request.actiondb.delete_run(authenticated_userid(request), run)
+    request.actiondb.purge_run(username, run)
 
-  cached_flash(request, 'Deleted run')
-  return HTTPFound(location=request.route_url('tests'))
+    cached_flash(request, "Purged run")
+    return HTTPFound(location=request.route_url("tests"))
 
 
-@view_config(route_name='tests_stats', renderer='tests_stats.mak')
+@view_config(route_name="tests_delete", require_csrf=True, request_method="POST")
+def tests_delete(request):
+    if not authenticated_userid(request):
+        request.session.flash("Please login")
+        return HTTPFound(location=request.route_url("login"))
+    if "run-id" in request.POST:
+        run = request.rundb.get_run(request.POST["run-id"])
+        if not can_modify_run(request, run):
+            request.session.flash("Unable to modify another user's run!", "error")
+            return HTTPFound(location=request.route_url("tests"))
+
+        run["deleted"] = True
+        run["finished"] = True
+        for w in run["tasks"]:
+            w["pending"] = False
+        request.rundb.buffer(run, True)
+        request.rundb.task_time = 0
+
+        run = del_tasks(run)
+        request.actiondb.delete_run(authenticated_userid(request), run)
+
+        cached_flash(request, "Deleted run")
+    return HTTPFound(location=request.route_url("tests"))
+
+
+@view_config(route_name="tests_stats", renderer="tests_stats.mak")
 def tests_stats(request):
-  run = request.rundb.get_run(request.matchdict['id'])
-  request.rundb.get_results(run)
-  return {'run': run}
+    run = request.rundb.get_run(request.matchdict["id"])
+    request.rundb.get_results(run)
+    return {"run": run}
 
 
-@view_config(route_name='tests_machines', renderer='machines_table.mak')
+@view_config(route_name="tests_machines", renderer="machines_table.mak")
 def tests_machines(request):
-  machines = request.rundb.get_machines()
-  for machine in machines:
-    machine['last_updated'] = delta_date(machine['last_updated'])
-  return {
-    'machines': machines
-  }
+    machines = request.rundb.get_machines()
+    for machine in machines:
+        machine["last_updated"] = delta_date(machine["last_updated"])
+    return {"machines": machines}
 
 
-@view_config(route_name='tests_view_spsa_history', renderer='json')
+@view_config(route_name="tests_view_spsa_history", renderer="json")
 def tests_view_spsa_history(request):
-  run = request.rundb.get_run(request.matchdict['id'])
-  if 'spsa' not in run['args']:
-    return {}
+    run = request.rundb.get_run(request.matchdict["id"])
+    if "spsa" not in run["args"]:
+        return {}
 
-  return run['args']['spsa']
+    return run["args"]["spsa"]
 
 
-@view_config(route_name='tests_view', renderer='tests_view.mak')
+@view_config(route_name="tests_view", renderer="tests_view.mak") def tests_view(request): - run = request.rundb.get_run(request.matchdict['id']) - if run is None: - raise exception_response(404) - results = request.rundb.get_results(run) - run['results_info'] = format_results(results, run) - run_args = [('id', str(run['_id']), '')] - if run.get('rescheduled_from'): - run_args.append(('rescheduled_from', run['rescheduled_from'], '')) - - for name in ['new_tag', 'new_signature', 'new_options', 'resolved_new', - 'new_net', - 'base_tag', 'base_signature', 'base_options', 'resolved_base', - 'base_net', - 'sprt', 'num_games', 'spsa', 'tc', 'threads', 'book', - 'book_depth', 'auto_purge', 'priority', 'itp', 'username', - 'tests_repo', 'info']: - - if name not in run['args']: - continue - - value = run['args'][name] - url = '' - - if name == 'new_tag' and 'msg_new' in run['args']: - value += ' (' + run['args']['msg_new'][:50] + ')' - - if name == 'base_tag' and 'msg_base' in run['args']: - value += ' (' + run['args']['msg_base'][:50] + ')' - - if name == 'sprt' and value != '-': - value = 'elo0: %.2f alpha: %.2f elo1: %.2f beta: %.2f state: %s (%s)' % \ - (value['elo0'], value['alpha'], value['elo1'], value['beta'], - value.get('state', '-'), value.get('elo_model', 'BayesElo')) - - if name == 'spsa' and value != '-': - iter_local = value['iter'] + 1 # assume at least one completed, - # and avoid division by zero - A = value['A'] - alpha = value['alpha'] - gamma = value['gamma'] - summary = 'Iter: %d, A: %d, alpha %0.3f, gamma %0.3f, clipping %s, rounding %s' \ - % (iter_local, A, alpha, gamma, - value['clipping'] if 'clipping' in value else 'old', - value['rounding'] if 'rounding' in value else 'deterministic') - params = value['params'] - value = [summary] - for p in params: - value.append([ - p['name'], - '{:.2f}'.format(p['theta']), - int(p['start']), - int(p['min']), - int(p['max']), - '{:.3f}'.format(p['c'] / (iter_local ** gamma)), - '{:.3f}'.format(p['a'] / (A + iter_local) ** alpha) - ]) - if 'tests_repo' in run['args']: - if name == 'new_tag': - url = run['args']['tests_repo'] + '/commit/' + run['args']['resolved_new'] - elif name == 'base_tag': - url = run['args']['tests_repo'] + '/commit/' + run['args']['resolved_base'] - elif name == 'tests_repo': - url = value - - if name == 'spsa': - run_args.append(('spsa', value, '')) + run = request.rundb.get_run(request.matchdict["id"]) + if run is None: + raise exception_response(404) + results = request.rundb.get_results(run) + run["results_info"] = format_results(results, run) + run_args = [("id", str(run["_id"]), "")] + if run.get("rescheduled_from"): + run_args.append(("rescheduled_from", run["rescheduled_from"], "")) + + for name in [ + "new_tag", + "new_signature", + "new_options", + "resolved_new", + "new_net", + "base_tag", + "base_signature", + "base_options", + "resolved_base", + "base_net", + "sprt", + "num_games", + "spsa", + "tc", + "threads", + "book", + "book_depth", + "auto_purge", + "priority", + "itp", + "username", + "tests_repo", + "info", + ]: + + if name not in run["args"]: + continue + + value = run["args"][name] + url = "" + + if name == "new_tag" and "msg_new" in run["args"]: + value += " (" + run["args"]["msg_new"][:50] + ")" + + if name == "base_tag" and "msg_base" in run["args"]: + value += " (" + run["args"]["msg_base"][:50] + ")" + + if name == "sprt" and value != "-": + value = "elo0: %.2f alpha: %.2f elo1: %.2f beta: %.2f state: %s (%s)" % ( + value["elo0"], + value["alpha"], + value["elo1"], + value["beta"], + 
value.get("state", "-"), + value.get("elo_model", "BayesElo"), + ) + + if name == "spsa" and value != "-": + iter_local = value["iter"] + 1 # assume at least one completed, + # and avoid division by zero + A = value["A"] + alpha = value["alpha"] + gamma = value["gamma"] + summary = ( + "Iter: %d, A: %d, alpha %0.3f, gamma %0.3f, clipping %s, rounding %s" + % ( + iter_local, + A, + alpha, + gamma, + value["clipping"] if "clipping" in value else "old", + value["rounding"] if "rounding" in value else "deterministic", + ) + ) + params = value["params"] + value = [summary] + for p in params: + value.append( + [ + p["name"], + "{:.2f}".format(p["theta"]), + int(p["start"]), + int(p["min"]), + int(p["max"]), + "{:.3f}".format(p["c"] / (iter_local ** gamma)), + "{:.3f}".format(p["a"] / (A + iter_local) ** alpha), + ] + ) + if "tests_repo" in run["args"]: + if name == "new_tag": + url = ( + run["args"]["tests_repo"] + "/commit/" + run["args"]["resolved_new"] + ) + elif name == "base_tag": + url = ( + run["args"]["tests_repo"] + + "/commit/" + + run["args"]["resolved_base"] + ) + elif name == "tests_repo": + url = value + + if name == "spsa": + run_args.append(("spsa", value, "")) + else: + try: + strval = str(value) + except: + strval = value.encode("ascii", "replace") + if name not in ["new_tag", "base_tag"]: + strval = html.escape(strval) + run_args.append((name, strval, url)) + + active = 0 + cores = 0 + for task in run["tasks"]: + if task["active"]: + active += 1 + cores += task["worker_info"]["concurrency"] + last_updated = task.get("last_updated", datetime.datetime.min) + task["last_updated"] = last_updated + + if run["args"].get("sprt"): + page_title = "SPRT {} vs {}".format( + run["args"]["new_tag"], run["args"]["base_tag"] + ) + elif run["args"].get("spsa"): + page_title = "SPSA {}".format(run["args"]["new_tag"]) else: - try: - strval = str(value) - except: - strval = value.encode('ascii', 'replace') - if name not in ['new_tag', 'base_tag']: - strval = html.escape(strval) - run_args.append((name, strval, url)) - - active = 0 - cores = 0 - for task in run['tasks']: - if task['active']: - active += 1 - cores += task['worker_info']['concurrency'] - last_updated = task.get('last_updated', datetime.datetime.min) - task['last_updated'] = last_updated - - if run['args'].get('sprt'): - page_title = 'SPRT {} vs {}'.format(run['args']['new_tag'], run['args']['base_tag']) - elif run['args'].get('spsa'): - page_title = 'SPSA {}'.format(run['args']['new_tag']) - else: - page_title = '{} games - {} vs {}'.format( - run['args']['num_games'], - run['args']['new_tag'], - run['args']['base_tag'] - ) - return {'run': run, 'run_args': run_args, 'page_title': page_title, - 'approver': has_permission('approve_run', request.context, request), - 'chi2': calculate_residuals(run), - 'totals': '(%s active worker%s with %s core%s)' - % (active, ('s' if active != 1 else ''), - cores, ('s' if cores != 1 else ''))} + page_title = "{} games - {} vs {}".format( + run["args"]["num_games"], run["args"]["new_tag"], run["args"]["base_tag"] + ) + return { + "run": run, + "run_args": run_args, + "page_title": page_title, + "approver": has_permission("approve_run", request.context, request), + "chi2": calculate_residuals(run), + "totals": "(%s active worker%s with %s core%s)" + % (active, ("s" if active != 1 else ""), cores, ("s" if cores != 1 else "")), + } def get_paginated_finished_runs(request): - username = request.matchdict.get('username', '') - success_only = request.params.get('success_only', False) - yellow_only = 
request.params.get('yellow_only', False)
-  ltc_only = request.params.get('ltc_only', False)
-
-  page_idx = max(0, int(request.params.get('page', 1)) - 1)
-  page_size = 25
-  finished_runs, num_finished_runs = request.rundb.get_finished_runs(
-      username=username, success_only=success_only,
-      yellow_only=yellow_only, ltc_only=ltc_only,
-      skip=page_idx * page_size, limit=page_size)
-
-  pages = [{'idx': 'Prev', 'url': '?page={}'.format(page_idx),
-            'state': 'disabled' if page_idx == 0 else ''}]
-  for idx, _ in enumerate(range(0, num_finished_runs, page_size)):
-    if idx < 5 or abs(page_idx - idx) < 5 or idx > (num_finished_runs / page_size) - 5:
-      pages.append({'idx': idx + 1, 'url': '?page={}'.format(idx + 1),
-                    'state': 'active' if page_idx == idx else ''})
-    elif pages[-1]['idx'] != '...':
-      pages.append({'idx': '...', 'url': '', 'state': 'disabled'})
-  pages.append({'idx': 'Next', 'url': '?page={}'.format(page_idx + 2),
-                'state': 'disabled' if page_idx + 1 == len(pages) - 1 else ''})
-
-  for page in pages:
-    if success_only:
-      page['url'] += '&success_only=1'
-    if yellow_only:
-      page['url'] += '&yellow_only=1'
-    if ltc_only:
-      page['url'] += '&ltc_only=1'
-
-  failed_runs = []
-  for run in finished_runs:
-    # Ensure finished runs have results_info
-    results = request.rundb.get_results(run)
-    if 'results_info' not in run:
-      run['results_info'] = format_results(results, run)
+    username = request.matchdict.get("username", "")
+    success_only = request.params.get("success_only", False)
+    yellow_only = request.params.get("yellow_only", False)
+    ltc_only = request.params.get("ltc_only", False)
+
+    page_idx = max(0, int(request.params.get("page", 1)) - 1)
+    page_size = 25
+    finished_runs, num_finished_runs = request.rundb.get_finished_runs(
+        username=username,
+        success_only=success_only,
+        yellow_only=yellow_only,
+        ltc_only=ltc_only,
+        skip=page_idx * page_size,
+        limit=page_size,
+    )
 
-    # Look for failed runs
-    if 'failed' in run:
-      failed_runs.append(run)
+    pages = [
+        {
+            "idx": "Prev",
+            "url": "?page={}".format(page_idx),
+            "state": "disabled" if page_idx == 0 else "",
+        }
+    ]
+    for idx, _ in enumerate(range(0, num_finished_runs, page_size)):
+        if (
+            idx < 5
+            or abs(page_idx - idx) < 5
+            or idx > (num_finished_runs / page_size) - 5
+        ):
+            pages.append(
+                {
+                    "idx": idx + 1,
+                    "url": "?page={}".format(idx + 1),
+                    "state": "active" if page_idx == idx else "",
+                }
+            )
+        elif pages[-1]["idx"] != "...":
+            pages.append({"idx": "...", "url": "", "state": "disabled"})
+    pages.append(
+        {
+            "idx": "Next",
+            "url": "?page={}".format(page_idx + 2),
+            "state": "disabled" if page_idx + 1 == len(pages) - 1 else "",
+        }
+    )
 
-  return {
-    'finished_runs': finished_runs,
-    'finished_runs_pages': pages,
-    'num_finished_runs': num_finished_runs,
-    'failed_runs': failed_runs,
-    'page_idx': page_idx,
-  }
+    for page in pages:
+        if success_only:
+            page["url"] += "&success_only=1"
+        if yellow_only:
+            page["url"] += "&yellow_only=1"
+        if ltc_only:
+            page["url"] += "&ltc_only=1"
+
+    failed_runs = []
+    for run in finished_runs:
+        # Ensure finished runs have results_info
+        results = request.rundb.get_results(run)
+        if "results_info" not in run:
+            run["results_info"] = format_results(results, run)
+
+        # Look for failed runs
+        if "failed" in run:
+            failed_runs.append(run)
+
+    return {
+        "finished_runs": finished_runs,
+        "finished_runs_pages": pages,
+        "num_finished_runs": num_finished_runs,
+        "failed_runs": failed_runs,
+        "page_idx": page_idx,
+    }
 
 
-@view_config(route_name='tests_finished', renderer='tests_finished.mak')
+@view_config(route_name="tests_finished", renderer="tests_finished.mak") def tests_finished(request): - return get_paginated_finished_runs(request) + return get_paginated_finished_runs(request) -@view_config(route_name='tests_user', renderer='tests_user.mak') +@view_config(route_name="tests_user", renderer="tests_user.mak") def tests_user(request): - username = request.matchdict.get('username', '') - response = { - **get_paginated_finished_runs(request), - 'username': username - } - if int(request.params.get('page', 1)) == 1: - response['runs'] = request.rundb.aggregate_unfinished_runs(username)[0] - # page 2 and beyond only show finished test results - return response + username = request.matchdict.get("username", "") + response = {**get_paginated_finished_runs(request), "username": username} + if int(request.params.get("page", 1)) == 1: + response["runs"] = request.rundb.aggregate_unfinished_runs(username)[0] + # page 2 and beyond only show finished test results + return response def homepage_results(request): - # Calculate games_per_minute from current machines - games_per_minute = 0.0 - machines = request.rundb.get_machines() - for machine in machines: - machine['last_updated'] = delta_date(machine['last_updated']) - if machine['nps'] != 0: - games_per_minute += ( - (machine['nps'] / 1600000.0) - * (60.0 / estimate_game_duration(machine['run']['args']['tc'])) - * (int(machine['concurrency']) // machine['run']['args'].get('threads', 1))) - machines.reverse() - # Get updated results for unfinished runs + finished runs - (runs, pending_hours, cores, nps) = request.rundb.aggregate_unfinished_runs() - return { - **get_paginated_finished_runs(request), - 'runs': runs, - 'machines': machines, - 'pending_hours': '%.1f' % (pending_hours), - 'cores': cores, - 'nps': nps, - 'games_per_minute': int(games_per_minute), - } + # Calculate games_per_minute from current machines + games_per_minute = 0.0 + machines = request.rundb.get_machines() + for machine in machines: + machine["last_updated"] = delta_date(machine["last_updated"]) + if machine["nps"] != 0: + games_per_minute += ( + (machine["nps"] / 1600000.0) + * (60.0 / estimate_game_duration(machine["run"]["args"]["tc"])) + * ( + int(machine["concurrency"]) + // machine["run"]["args"].get("threads", 1) + ) + ) + machines.reverse() + # Get updated results for unfinished runs + finished runs + (runs, pending_hours, cores, nps) = request.rundb.aggregate_unfinished_runs() + return { + **get_paginated_finished_runs(request), + "runs": runs, + "machines": machines, + "pending_hours": "%.1f" % (pending_hours), + "cores": cores, + "nps": nps, + "games_per_minute": int(games_per_minute), + } # For caching the homepage tests output @@ -1028,35 +1153,36 @@ def homepage_results(request): # Guard against parallel builds of main page building = threading.Semaphore() -@view_config(route_name='tests', renderer='tests.mak') -def tests(request): - if int(request.params.get('page', 1)) > 1: - # page 2 and beyond only show finished test results - return get_paginated_finished_runs(request) - global last_tests, last_time - if time.time() - last_time > cache_time: - acquired = building.acquire(last_tests is None) - if not acquired: - # We have a current cache and another thread is rebuilding, - # so return the current cache - pass - elif time.time() - last_time < cache_time: - # Another thread has built the cache for us, so we are done - building.release() - else: - # Not cached, so calculate and fetch homepage results - try: - last_tests = homepage_results(request) - 
except Exception as e: - print('Overview exception: ' + str(e)) - if not last_tests: - raise e - finally: - last_time = time.time() - building.release() - return { - **last_tests, - 'machines_shown': request.cookies.get('machines_state') == 'Hide', - 'pending_shown': request.cookies.get('pending_state') == 'Hide' - } +@view_config(route_name="tests", renderer="tests.mak") +def tests(request): + if int(request.params.get("page", 1)) > 1: + # page 2 and beyond only show finished test results + return get_paginated_finished_runs(request) + + global last_tests, last_time + if time.time() - last_time > cache_time: + acquired = building.acquire(last_tests is None) + if not acquired: + # We have a current cache and another thread is rebuilding, + # so return the current cache + pass + elif time.time() - last_time < cache_time: + # Another thread has built the cache for us, so we are done + building.release() + else: + # Not cached, so calculate and fetch homepage results + try: + last_tests = homepage_results(request) + except Exception as e: + print("Overview exception: " + str(e)) + if not last_tests: + raise e + finally: + last_time = time.time() + building.release() + return { + **last_tests, + "machines_shown": request.cookies.get("machines_state") == "Hide", + "pending_shown": request.cookies.get("pending_state") == "Hide", + } diff --git a/fishtest/run_all_tests.py b/fishtest/run_all_tests.py index 6f9457c50..f69bb4495 100644 --- a/fishtest/run_all_tests.py +++ b/fishtest/run_all_tests.py @@ -1,7 +1,7 @@ import unittest -def server_test_suite(): - test_loader = unittest.TestLoader() - test_suite = test_loader.discover('tests', pattern='test_*.py') - return test_suite +def server_test_suite(): + test_loader = unittest.TestLoader() + test_suite = test_loader.discover("tests", pattern="test_*.py") + return test_suite diff --git a/fishtest/setup.py b/fishtest/setup.py index f1aaf85ad..79a1e3223 100644 --- a/fishtest/setup.py +++ b/fishtest/setup.py @@ -2,43 +2,44 @@ from setuptools import setup, find_packages -README = '' -CHANGES = '' +README = "" +CHANGES = "" requires = [ - 'pyramid', - 'pyramid_debugtoolbar', - 'pyramid_mako', - 'waitress', - 'pymongo', - 'numpy', - 'scipy', - 'requests', - 'awscli', - ] + "pyramid", + "pyramid_debugtoolbar", + "pyramid_mako", + "waitress", + "pymongo", + "numpy", + "scipy", + "requests", + "awscli", +] -setup(name='fishtest-server', - version='0.1', - description='fishtest-server', - long_description=README + '\n\n' + CHANGES, - classifiers=[ +setup( + name="fishtest-server", + version="0.1", + description="fishtest-server", + long_description=README + "\n\n" + CHANGES, + classifiers=[ "Programming Language :: Python", "Framework :: Pyramid", "Topic :: Internet :: WWW/HTTP", "Topic :: Internet :: WWW/HTTP :: WSGI :: Application", - ], - author='', - author_email='', - url='', - keywords='web pyramid pylons', - packages=find_packages(), - include_package_data=True, - zip_safe=False, - install_requires=requires, - tests_require=requires, - test_suite="run_all_tests.server_test_suite", - entry_points="""\ + ], + author="", + author_email="", + url="", + keywords="web pyramid pylons", + packages=find_packages(), + include_package_data=True, + zip_safe=False, + install_requires=requires, + tests_require=requires, + test_suite="run_all_tests.server_test_suite", + entry_points="""\ [paste.app_factory] main=fishtest:main """, - ) +) diff --git a/fishtest/tests/test_api.py b/fishtest/tests/test_api.py index 618a25e4e..2b1815611 100644 --- 
a/fishtest/tests/test_api.py +++ b/fishtest/tests/test_api.py @@ -11,458 +11,464 @@ class TestApi(unittest.TestCase): - - @classmethod - def setUpClass(self): - self.rundb = get_rundb() - - # Set up a run - num_tasks = 4 - num_games = num_tasks * self.rundb.chunk_size - run_id = self.rundb.new_run('master', 'master', num_games, '10+0.01', - 'book', 10, 1, '', '', - username='travis', tests_repo='travis', - start_time=datetime.datetime.utcnow()) - self.run_id = str(run_id) - run = self.rundb.get_run(self.run_id) - run['approved'] = True - - # Set up a task - self.task_id = 0 - for i, task in enumerate(run['tasks']): - if i is not self.task_id: - run['tasks'][i]['pending'] = False - - self.rundb.buffer(run, True) - - # Set up an API user (a worker) - self.username = 'JoeUserWorker' - self.password = 'secret' - self.remote_addr = '127.0.0.1' - self.concurrency = 7 - - self.worker_info = { - 'username': self.username, - 'password': self.password, - 'remote_addr': self.remote_addr, - 'concurrency': self.concurrency, - 'unique_key': 'unique key', - 'version': WORKER_VERSION - } - self.rundb.userdb.create_user(self.username, self.password, 'email@email.email') - user = self.rundb.userdb.get_user(self.username) - user['blocked'] = False - user['machine_limit'] = 50 - self.rundb.userdb.save_user(user) - - self.rundb.userdb.user_cache.insert_one({ - 'username': self.username, - 'cpu_hours': 0, - }) - self.rundb.userdb.flag_cache.insert_one({ - 'ip': self.remote_addr, - 'country_code': '??' - }) - - @classmethod - def tearDownClass(self): - self.rundb.runs.delete_one({ '_id': self.run_id }) - self.rundb.userdb.users.delete_many({ 'username': self.username }) - self.rundb.userdb.user_cache.delete_many({ 'username': self.username }) - self.rundb.userdb.flag_cache.delete_many({ 'ip': self.remote_addr }) - self.rundb.stop() - self.rundb.runs.drop() - - - def build_json_request(self, json_body): - return DummyRequest( - rundb=self.rundb, - userdb=self.rundb.userdb, - actiondb=self.rundb.actiondb, - remote_addr=self.remote_addr, - json_body=json_body - ) - - def invalid_password_request(self): - return self.build_json_request({ - 'username': self.username, - 'password': 'wrong password' - }) - - def correct_password_request(self, json_body={}): - return self.build_json_request({ - 'username': self.username, - 'password': self.password, - **json_body, - }) - - - def test_get_active_runs(self): - request = DummyRequest(rundb=self.rundb) - response = ApiView(request).active_runs() - self.assertTrue(self.run_id in response) - - - def test_get_run(self): - request = DummyRequest( - rundb=self.rundb, - matchdict={'id': self.run_id} - ) - response = ApiView(request).get_run() - self.assertEqual(self.run_id, response['_id']) - - - def test_get_elo(self): - request = DummyRequest( - rundb=self.rundb, - matchdict={'id': self.run_id} - ) - response = ApiView(request).get_elo() - self.assertTrue(not response) - - - def test_request_task(self): - with self.assertRaises(HTTPUnauthorized): - response = ApiView(self.invalid_password_request()).update_task() - self.assertTrue('error' in response) - - run = self.rundb.get_run(self.run_id) - self.assertEqual(run.get('cores'), None) - - run['tasks'][self.task_id] = { - 'num_games': self.rundb.chunk_size, - 'stats': { 'wins': 0, 'draws': 0, 'losses': 0, 'crashes': 0 }, - 'pending': True, - 'active': False, - } - self.rundb.buffer(run, True) - - request = self.correct_password_request({ 'worker_info': self.worker_info }) - response = ApiView(request).request_task() - 
self.assertEqual(self.run_id, response['run']['_id']) - self.assertEqual(self.task_id, response['task_id']) - - run = self.rundb.get_run(self.run_id) - self.assertEqual(run['cores'], self.concurrency) - task = run['tasks'][self.task_id] - self.assertTrue(task['pending']) - self.assertTrue(task['active']) - - - def test_update_task(self): - self.assertFalse(self.rundb.get_run(self.run_id)['results_stale']) - - # Request fails if username/password is invalid - with self.assertRaises(HTTPUnauthorized): - response = ApiView(self.invalid_password_request()).update_task() - self.assertTrue('error' in response) - - # Prepare a pending task that will be assigned to this worker - run = self.rundb.get_run(self.run_id) - run['tasks'][self.task_id] = { - 'num_games': self.rundb.chunk_size, - 'pending': True, - 'active': False - } - if run['args'].get('spsa'): - del run['args']['spsa'] - self.rundb.buffer(run, True) - - # Calling /api/request_task assigns this task to the worker - request = self.correct_password_request({ 'worker_info': self.worker_info }) - response = ApiView(request).request_task() - self.assertEqual(response['run']['_id'], str(run['_id'])) - - # Task is active after calling /api/update_task with the first set of results - request = self.correct_password_request({ - 'worker_info': self.worker_info, - 'run_id': self.run_id, - 'task_id': self.task_id, - 'stats': { 'wins': 2, 'draws': 0, 'losses': 0, 'crashes': 0 } - }) - response = ApiView(request).update_task() - self.assertTrue(response['task_alive']) - self.assertTrue(self.rundb.get_run(self.run_id)['results_stale']) - - # Task is still active - cs=self.rundb.chunk_size - w,d,l=cs/2-10,cs/2,0 - request.json_body['stats'] = { - 'wins': w, 'draws': d, 'losses': l, 'crashes': 0 - } - response = ApiView(request).update_task() - self.assertTrue(response['task_alive']) - self.assertTrue(self.rundb.get_run(self.run_id)['results_stale']) - - # Task is still active. Odd update. - request.json_body['stats'] = { - 'wins': w+1, 'draws': d, 'losses': 0, 'crashes': 0 - } - response = ApiView(request).update_task() - self.assertFalse(response['task_alive']) - - # Task_alive is a misnomer... 
- request.json_body['stats'] = { - 'wins': w+2, 'draws': d, 'losses': 0, 'crashes': 0 - } - response = ApiView(request).update_task() - self.assertTrue(response['task_alive']) - - # Go back in time - request.json_body['stats'] = { - 'wins': w, 'draws': d, 'losses': 0, 'crashes': 0 - } - response = ApiView(request).update_task() - self.assertFalse(response['task_alive']) - - # Task is finished when calling /api/update_task with results where the number of - # games played is the same as the number of games in the task - task_num_games = run['tasks'][self.task_id]['num_games'] - request.json_body['stats'] = { - 'wins': task_num_games, 'draws': 0, 'losses': 0, 'crashes': 0 - } - response = ApiView(request).update_task() - self.assertFalse(self.rundb.get_run(self.run_id)['results_stale']) - self.assertFalse(response['task_alive']) - run = self.rundb.get_run(self.run_id) - task = run['tasks'][self.task_id] - self.assertFalse(task['pending']) - self.assertFalse(task['active']) - - - def test_failed_task(self): - request = self.correct_password_request({ - 'run_id': self.run_id, - 'task_id': 0, - }) - response = ApiView(request).failed_task() - self.assertFalse(response['task_alive']) - - run = self.rundb.get_run(self.run_id) - run['tasks'][self.task_id]['active'] = True - run['tasks'][self.task_id]['worker_info'] = self.worker_info - self.rundb.buffer(run, True) - run = self.rundb.get_run(self.run_id) - self.assertTrue(run['tasks'][self.task_id]['active']) - - request = self.correct_password_request({ - 'run_id': self.run_id, - 'task_id': self.task_id, - }) - response = ApiView(request).failed_task() - self.assertTrue(not response) - self.assertFalse(run['tasks'][self.task_id]['active']) - - - def test_stop_run(self): - with self.assertRaises(HTTPUnauthorized): - response = ApiView(self.invalid_password_request()).stop_run() - self.assertTrue('error' in response) - - run = self.rundb.get_run(self.run_id) - self.assertFalse(run['finished']) - - request = self.correct_password_request({ 'run_id': self.run_id }) - response = ApiView(request).stop_run() - self.assertTrue(not response) - - self.rundb.userdb.user_cache.update_one({ 'username': self.username }, { - '$set': { - 'cpu_hours': 10000 - } - }) - user = self.rundb.userdb.user_cache.find_one({ 'username': self.username }) - self.assertTrue(user['cpu_hours'] == 10000) - - response = ApiView(request).stop_run() - self.assertTrue(not response) - - run = self.rundb.get_run(self.run_id) - self.assertTrue(run['finished']) - self.assertEqual(run['stop_reason'], 'API request') - - run['finished'] = False - self.rundb.buffer(run, True) - - - def test_upload_pgn(self): - pgn_text = '1. e4 e5 2. 
d4 d5' - request = self.correct_password_request({ - 'run_id': self.run_id, - 'task_id': self.task_id, - 'pgn': base64.b64encode(zlib.compress(pgn_text.encode('utf-8'))).decode() - }) - response = ApiView(request).upload_pgn() - self.assertTrue(not response) - - pgn_filename_prefix = '{}-{}'.format(self.run_id, self.task_id) - pgn = self.rundb.get_pgn(pgn_filename_prefix) - self.assertEqual(pgn, pgn_text) - self.rundb.pgndb.delete_one({ 'run_id': pgn_filename_prefix }) - - - def test_request_spsa(self): - request = self.correct_password_request({ - 'run_id': self.run_id, - 'task_id': 0, - }) - response = ApiView(request).request_spsa() - self.assertFalse(response['task_alive']) - - run = self.rundb.get_run(self.run_id) - run['args']['spsa'] = { - 'iter': 1, - 'num_iter': 10, - 'alpha': 1, - 'gamma': 1, - 'A': 1, - 'params': [{ - 'name': 'param name', - 'a': 1, - 'c': 1, - 'theta': 1, - 'min': 0, - 'max': 100, - }] - } - run['tasks'][self.task_id]['pending'] = True - run['tasks'][self.task_id]['active'] = True - self.rundb.buffer(run, True) - request = self.correct_password_request({ - 'run_id': self.run_id, - 'task_id': self.task_id, - }) - response = ApiView(request).request_spsa() - self.assertTrue(response['task_alive']) - self.assertTrue(response['w_params'] is not None) - self.assertTrue(response['b_params'] is not None) - - - def test_request_version(self): - with self.assertRaises(HTTPUnauthorized): - response = ApiView(self.invalid_password_request()).request_version() - self.assertTrue('error' in response) - - response = ApiView(self.correct_password_request()).request_version() - self.assertEqual(WORKER_VERSION, response['version']) + @classmethod + def setUpClass(self): + self.rundb = get_rundb() + + # Set up a run + num_tasks = 4 + num_games = num_tasks * self.rundb.chunk_size + run_id = self.rundb.new_run( + "master", + "master", + num_games, + "10+0.01", + "book", + 10, + 1, + "", + "", + username="travis", + tests_repo="travis", + start_time=datetime.datetime.utcnow(), + ) + self.run_id = str(run_id) + run = self.rundb.get_run(self.run_id) + run["approved"] = True + + # Set up a task + self.task_id = 0 + for i, task in enumerate(run["tasks"]): + if i is not self.task_id: + run["tasks"][i]["pending"] = False + + self.rundb.buffer(run, True) + + # Set up an API user (a worker) + self.username = "JoeUserWorker" + self.password = "secret" + self.remote_addr = "127.0.0.1" + self.concurrency = 7 + + self.worker_info = { + "username": self.username, + "password": self.password, + "remote_addr": self.remote_addr, + "concurrency": self.concurrency, + "unique_key": "unique key", + "version": WORKER_VERSION, + } + self.rundb.userdb.create_user(self.username, self.password, "email@email.email") + user = self.rundb.userdb.get_user(self.username) + user["blocked"] = False + user["machine_limit"] = 50 + self.rundb.userdb.save_user(user) + + self.rundb.userdb.user_cache.insert_one( + {"username": self.username, "cpu_hours": 0} + ) + self.rundb.userdb.flag_cache.insert_one( + {"ip": self.remote_addr, "country_code": "??"} + ) + + @classmethod + def tearDownClass(self): + self.rundb.runs.delete_one({"_id": self.run_id}) + self.rundb.userdb.users.delete_many({"username": self.username}) + self.rundb.userdb.user_cache.delete_many({"username": self.username}) + self.rundb.userdb.flag_cache.delete_many({"ip": self.remote_addr}) + self.rundb.stop() + self.rundb.runs.drop() + + def build_json_request(self, json_body): + return DummyRequest( + rundb=self.rundb, + userdb=self.rundb.userdb, + 
actiondb=self.rundb.actiondb, + remote_addr=self.remote_addr, + json_body=json_body, + ) + + def invalid_password_request(self): + return self.build_json_request( + {"username": self.username, "password": "wrong password"} + ) + + def correct_password_request(self, json_body={}): + return self.build_json_request( + {"username": self.username, "password": self.password, **json_body} + ) + + def test_get_active_runs(self): + request = DummyRequest(rundb=self.rundb) + response = ApiView(request).active_runs() + self.assertTrue(self.run_id in response) + + def test_get_run(self): + request = DummyRequest(rundb=self.rundb, matchdict={"id": self.run_id}) + response = ApiView(request).get_run() + self.assertEqual(self.run_id, response["_id"]) + + def test_get_elo(self): + request = DummyRequest(rundb=self.rundb, matchdict={"id": self.run_id}) + response = ApiView(request).get_elo() + self.assertTrue(not response) + + def test_request_task(self): + with self.assertRaises(HTTPUnauthorized): + response = ApiView(self.invalid_password_request()).update_task() + self.assertTrue("error" in response) + + run = self.rundb.get_run(self.run_id) + self.assertEqual(run.get("cores"), None) + + run["tasks"][self.task_id] = { + "num_games": self.rundb.chunk_size, + "stats": {"wins": 0, "draws": 0, "losses": 0, "crashes": 0}, + "pending": True, + "active": False, + } + self.rundb.buffer(run, True) + + request = self.correct_password_request({"worker_info": self.worker_info}) + response = ApiView(request).request_task() + self.assertEqual(self.run_id, response["run"]["_id"]) + self.assertEqual(self.task_id, response["task_id"]) + + run = self.rundb.get_run(self.run_id) + self.assertEqual(run["cores"], self.concurrency) + task = run["tasks"][self.task_id] + self.assertTrue(task["pending"]) + self.assertTrue(task["active"]) + + def test_update_task(self): + self.assertFalse(self.rundb.get_run(self.run_id)["results_stale"]) + + # Request fails if username/password is invalid + with self.assertRaises(HTTPUnauthorized): + response = ApiView(self.invalid_password_request()).update_task() + self.assertTrue("error" in response) + + # Prepare a pending task that will be assigned to this worker + run = self.rundb.get_run(self.run_id) + run["tasks"][self.task_id] = { + "num_games": self.rundb.chunk_size, + "pending": True, + "active": False, + } + if run["args"].get("spsa"): + del run["args"]["spsa"] + self.rundb.buffer(run, True) + + # Calling /api/request_task assigns this task to the worker + request = self.correct_password_request({"worker_info": self.worker_info}) + response = ApiView(request).request_task() + self.assertEqual(response["run"]["_id"], str(run["_id"])) + + # Task is active after calling /api/update_task with the first set of results + request = self.correct_password_request( + { + "worker_info": self.worker_info, + "run_id": self.run_id, + "task_id": self.task_id, + "stats": {"wins": 2, "draws": 0, "losses": 0, "crashes": 0}, + } + ) + response = ApiView(request).update_task() + self.assertTrue(response["task_alive"]) + self.assertTrue(self.rundb.get_run(self.run_id)["results_stale"]) + + # Task is still active + cs = self.rundb.chunk_size + w, d, l = cs / 2 - 10, cs / 2, 0 + request.json_body["stats"] = {"wins": w, "draws": d, "losses": l, "crashes": 0} + response = ApiView(request).update_task() + self.assertTrue(response["task_alive"]) + self.assertTrue(self.rundb.get_run(self.run_id)["results_stale"]) + + # Task is still active. Odd update. 
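+        # (an update whose cumulative game count is odd is rejected below,
+        # presumably because worker games are played and reported in pairs)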
+ request.json_body["stats"] = { + "wins": w + 1, + "draws": d, + "losses": 0, + "crashes": 0, + } + response = ApiView(request).update_task() + self.assertFalse(response["task_alive"]) + + # Task_alive is a misnomer... + request.json_body["stats"] = { + "wins": w + 2, + "draws": d, + "losses": 0, + "crashes": 0, + } + response = ApiView(request).update_task() + self.assertTrue(response["task_alive"]) + + # Go back in time + request.json_body["stats"] = {"wins": w, "draws": d, "losses": 0, "crashes": 0} + response = ApiView(request).update_task() + self.assertFalse(response["task_alive"]) + + # Task is finished when calling /api/update_task with results where the number of + # games played is the same as the number of games in the task + task_num_games = run["tasks"][self.task_id]["num_games"] + request.json_body["stats"] = { + "wins": task_num_games, + "draws": 0, + "losses": 0, + "crashes": 0, + } + response = ApiView(request).update_task() + self.assertFalse(self.rundb.get_run(self.run_id)["results_stale"]) + self.assertFalse(response["task_alive"]) + run = self.rundb.get_run(self.run_id) + task = run["tasks"][self.task_id] + self.assertFalse(task["pending"]) + self.assertFalse(task["active"]) + + def test_failed_task(self): + request = self.correct_password_request({"run_id": self.run_id, "task_id": 0}) + response = ApiView(request).failed_task() + self.assertFalse(response["task_alive"]) + + run = self.rundb.get_run(self.run_id) + run["tasks"][self.task_id]["active"] = True + run["tasks"][self.task_id]["worker_info"] = self.worker_info + self.rundb.buffer(run, True) + run = self.rundb.get_run(self.run_id) + self.assertTrue(run["tasks"][self.task_id]["active"]) + + request = self.correct_password_request( + {"run_id": self.run_id, "task_id": self.task_id} + ) + response = ApiView(request).failed_task() + self.assertTrue(not response) + self.assertFalse(run["tasks"][self.task_id]["active"]) + + def test_stop_run(self): + with self.assertRaises(HTTPUnauthorized): + response = ApiView(self.invalid_password_request()).stop_run() + self.assertTrue("error" in response) + + run = self.rundb.get_run(self.run_id) + self.assertFalse(run["finished"]) + + request = self.correct_password_request({"run_id": self.run_id}) + response = ApiView(request).stop_run() + self.assertTrue(not response) + + self.rundb.userdb.user_cache.update_one( + {"username": self.username}, {"$set": {"cpu_hours": 10000}} + ) + user = self.rundb.userdb.user_cache.find_one({"username": self.username}) + self.assertTrue(user["cpu_hours"] == 10000) + + response = ApiView(request).stop_run() + self.assertTrue(not response) + + run = self.rundb.get_run(self.run_id) + self.assertTrue(run["finished"]) + self.assertEqual(run["stop_reason"], "API request") + + run["finished"] = False + self.rundb.buffer(run, True) + + def test_upload_pgn(self): + pgn_text = "1. e4 e5 2. 
d4 d5" + request = self.correct_password_request( + { + "run_id": self.run_id, + "task_id": self.task_id, + "pgn": base64.b64encode( + zlib.compress(pgn_text.encode("utf-8")) + ).decode(), + } + ) + response = ApiView(request).upload_pgn() + self.assertTrue(not response) + + pgn_filename_prefix = "{}-{}".format(self.run_id, self.task_id) + pgn = self.rundb.get_pgn(pgn_filename_prefix) + self.assertEqual(pgn, pgn_text) + self.rundb.pgndb.delete_one({"run_id": pgn_filename_prefix}) + + def test_request_spsa(self): + request = self.correct_password_request({"run_id": self.run_id, "task_id": 0}) + response = ApiView(request).request_spsa() + self.assertFalse(response["task_alive"]) + + run = self.rundb.get_run(self.run_id) + run["args"]["spsa"] = { + "iter": 1, + "num_iter": 10, + "alpha": 1, + "gamma": 1, + "A": 1, + "params": [ + {"name": "param name", "a": 1, "c": 1, "theta": 1, "min": 0, "max": 100} + ], + } + run["tasks"][self.task_id]["pending"] = True + run["tasks"][self.task_id]["active"] = True + self.rundb.buffer(run, True) + request = self.correct_password_request( + {"run_id": self.run_id, "task_id": self.task_id} + ) + response = ApiView(request).request_spsa() + self.assertTrue(response["task_alive"]) + self.assertTrue(response["w_params"] is not None) + self.assertTrue(response["b_params"] is not None) + + def test_request_version(self): + with self.assertRaises(HTTPUnauthorized): + response = ApiView(self.invalid_password_request()).request_version() + self.assertTrue("error" in response) + + response = ApiView(self.correct_password_request()).request_version() + self.assertEqual(WORKER_VERSION, response["version"]) class TestRunFinished(unittest.TestCase): - - @classmethod - def setUpClass(self): - self.rundb = get_rundb() - - # Set up a run with 2 tasks - num_games = 2 * self.rundb.chunk_size - run_id = self.rundb.new_run('master', 'master', num_games, '10+0.01', - 'book', 10, 1, '', '', - username='travis', tests_repo='travis', - start_time=datetime.datetime.utcnow()) - self.run_id = str(run_id) - run = self.rundb.get_run(self.run_id) - run['approved'] = True - - # Set up a user - self.username = 'JoeUserWorker2' - self.password = 'secret' - self.remote_addr = '127.0.0.1' - self.concurrency = 7 - self.worker_info = { - 'username': self.username, - 'password': self.password, - 'remote_addr': self.remote_addr, - 'concurrency': self.concurrency, - 'unique_key': 'unique key', - 'version': WORKER_VERSION - } - - self.rundb.userdb.create_user(self.username, self.password, 'email@email.email') - user = self.rundb.userdb.get_user(self.username) - user['blocked'] = False - user['machine_limit'] = 50 - self.rundb.userdb.save_user(user) - self.rundb.userdb.user_cache.insert_one({ - 'username': self.username, - 'cpu_hours': 0, - }) - self.rundb.userdb.flag_cache.insert_one({ - 'ip': self.remote_addr, - 'country_code': '??' 
- }) - - @classmethod - def tearDownClass(self): - self.rundb.userdb.users.delete_many({ 'username': self.username }) - self.rundb.userdb.user_cache.delete_many({ 'username': self.username }) - self.rundb.stop() - self.rundb.runs.drop() - - def build_json_request(self, json_body): - return DummyRequest( - rundb=self.rundb, - userdb=self.rundb.userdb, - actiondb=self.rundb.actiondb, - remote_addr=self.remote_addr, - json_body=json_body - ) - - def correct_password_request(self, json_body={}): - return self.build_json_request({ - 'username': self.username, - 'password': self.password, - **json_body, - }) - - - def test_auto_purge_runs(self): - run = self.rundb.get_run(self.run_id) - - # Request task 1 of 2 - request = self.correct_password_request({ 'worker_info': self.worker_info }) - response = ApiView(request).request_task() - self.assertEqual(response['run']['_id'], str(run['_id'])) - self.assertEqual(response['task_id'], 0) - - # Request task 2 of 2 - request = self.correct_password_request({ 'worker_info': self.worker_info }) - response = ApiView(request).request_task() - self.assertEqual(response['run']['_id'], str(run['_id'])) - self.assertEqual(response['task_id'], 1) - - n_wins = self.rundb.chunk_size / 5 - n_losses = self.rundb.chunk_size / 5 - n_draws = self.rundb.chunk_size * 3/5 - - # Finish task 1 of 2 - request = self.correct_password_request({ - 'worker_info': self.worker_info, - 'run_id': self.run_id, - 'task_id': 0, - 'stats': { 'wins': n_wins, 'draws': n_draws, 'losses': n_losses, 'crashes': 0 } - }) - response = ApiView(request).update_task() - self.assertFalse(response['task_alive']) - run = self.rundb.get_run(self.run_id) - self.assertFalse(run['finished']) - - # Finish task 2 of 2 - request = self.correct_password_request({ - 'worker_info': self.worker_info, - 'run_id': self.run_id, - 'task_id': 1, - 'stats': { 'wins': n_wins, 'draws': n_draws, 'losses': n_losses, 'crashes': 0 } - }) - response = ApiView(request).update_task() - self.assertFalse(response['task_alive']) - - # The run should be marked as finished after the last task completes - run = self.rundb.get_run(self.run_id) - self.assertTrue(run['finished']) - self.assertFalse(run['results_stale']) - self.assertTrue(all([not t['pending'] and not t['active'] for t in run['tasks']])) - self.assertTrue('Total: {}'.format(self.rundb.chunk_size * 2) in run['results_info']['info'][1]) - - -if __name__ == '__main__': - unittest.main() + @classmethod + def setUpClass(self): + self.rundb = get_rundb() + + # Set up a run with 2 tasks + num_games = 2 * self.rundb.chunk_size + run_id = self.rundb.new_run( + "master", + "master", + num_games, + "10+0.01", + "book", + 10, + 1, + "", + "", + username="travis", + tests_repo="travis", + start_time=datetime.datetime.utcnow(), + ) + self.run_id = str(run_id) + run = self.rundb.get_run(self.run_id) + run["approved"] = True + + # Set up a user + self.username = "JoeUserWorker2" + self.password = "secret" + self.remote_addr = "127.0.0.1" + self.concurrency = 7 + self.worker_info = { + "username": self.username, + "password": self.password, + "remote_addr": self.remote_addr, + "concurrency": self.concurrency, + "unique_key": "unique key", + "version": WORKER_VERSION, + } + + self.rundb.userdb.create_user(self.username, self.password, "email@email.email") + user = self.rundb.userdb.get_user(self.username) + user["blocked"] = False + user["machine_limit"] = 50 + self.rundb.userdb.save_user(user) + self.rundb.userdb.user_cache.insert_one( + {"username": self.username, "cpu_hours": 0} + ) 
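+        # flag_cache ties a worker IP to a country code ("??" = unknown),
+        # presumably backing the country flag shown for workers in the UI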
+ self.rundb.userdb.flag_cache.insert_one( + {"ip": self.remote_addr, "country_code": "??"} + ) + + @classmethod + def tearDownClass(self): + self.rundb.userdb.users.delete_many({"username": self.username}) + self.rundb.userdb.user_cache.delete_many({"username": self.username}) + self.rundb.stop() + self.rundb.runs.drop() + + def build_json_request(self, json_body): + return DummyRequest( + rundb=self.rundb, + userdb=self.rundb.userdb, + actiondb=self.rundb.actiondb, + remote_addr=self.remote_addr, + json_body=json_body, + ) + + def correct_password_request(self, json_body={}): + return self.build_json_request( + {"username": self.username, "password": self.password, **json_body} + ) + + def test_auto_purge_runs(self): + run = self.rundb.get_run(self.run_id) + + # Request task 1 of 2 + request = self.correct_password_request({"worker_info": self.worker_info}) + response = ApiView(request).request_task() + self.assertEqual(response["run"]["_id"], str(run["_id"])) + self.assertEqual(response["task_id"], 0) + + # Request task 2 of 2 + request = self.correct_password_request({"worker_info": self.worker_info}) + response = ApiView(request).request_task() + self.assertEqual(response["run"]["_id"], str(run["_id"])) + self.assertEqual(response["task_id"], 1) + + n_wins = self.rundb.chunk_size / 5 + n_losses = self.rundb.chunk_size / 5 + n_draws = self.rundb.chunk_size * 3 / 5 + + # Finish task 1 of 2 + request = self.correct_password_request( + { + "worker_info": self.worker_info, + "run_id": self.run_id, + "task_id": 0, + "stats": { + "wins": n_wins, + "draws": n_draws, + "losses": n_losses, + "crashes": 0, + }, + } + ) + response = ApiView(request).update_task() + self.assertFalse(response["task_alive"]) + run = self.rundb.get_run(self.run_id) + self.assertFalse(run["finished"]) + + # Finish task 2 of 2 + request = self.correct_password_request( + { + "worker_info": self.worker_info, + "run_id": self.run_id, + "task_id": 1, + "stats": { + "wins": n_wins, + "draws": n_draws, + "losses": n_losses, + "crashes": 0, + }, + } + ) + response = ApiView(request).update_task() + self.assertFalse(response["task_alive"]) + + # The run should be marked as finished after the last task completes + run = self.rundb.get_run(self.run_id) + self.assertTrue(run["finished"]) + self.assertFalse(run["results_stale"]) + self.assertTrue( + all([not t["pending"] and not t["active"] for t in run["tasks"]]) + ) + self.assertTrue( + "Total: {}".format(self.rundb.chunk_size * 2) + in run["results_info"]["info"][1] + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/fishtest/tests/test_run.py b/fishtest/tests/test_run.py index 562a757ea..5fcee3861 100644 --- a/fishtest/tests/test_run.py +++ b/fishtest/tests/test_run.py @@ -3,11 +3,11 @@ from fishtest.views import get_master_bench -class CreateRunTest(unittest.TestCase): - def test_10_get_bench(self): - self.assertTrue(re.match('[0-9]{7}|None', str(get_master_bench()))) +class CreateRunTest(unittest.TestCase): + def test_10_get_bench(self): + self.assertTrue(re.match("[0-9]{7}|None", str(get_master_bench()))) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/fishtest/tests/test_rundb.py b/fishtest/tests/test_rundb.py index 04521ff60..7169a0e81 100644 --- a/fishtest/tests/test_rundb.py +++ b/fishtest/tests/test_rundb.py @@ -7,84 +7,145 @@ run_id = None -class CreateRunDBTest(unittest.TestCase): - def setUp(self): - self.rundb = util.get_rundb() - self.rundb.runs.create_index( - [('last_updated', DESCENDING), ('tc_base', DESCENDING)], - 
name='finished_ltc_runs', - partialFilterExpression={ 'finished': True, 'tc_base': { '$gte': 40 } } - ) +class CreateRunDBTest(unittest.TestCase): + def setUp(self): + self.rundb = util.get_rundb() + self.rundb.runs.create_index( + [("last_updated", DESCENDING), ("tc_base", DESCENDING)], + name="finished_ltc_runs", + partialFilterExpression={"finished": True, "tc_base": {"$gte": 40}}, + ) - def tearDown(self): - self.rundb.runs.delete_many({ 'args.username': 'travis' }) - # Shutdown flush thread: - self.rundb.stop() + def tearDown(self): + self.rundb.runs.delete_many({"args.username": "travis"}) + # Shutdown flush thread: + self.rundb.stop() - def test_10_create_run(self): - global run_id - # STC - run_id_stc = self.rundb.new_run('master', 'master', 100000, '10+0.01', 'book', 10, 1, '', '', - username='travis', tests_repo='travis', - start_time=datetime.datetime.utcnow()) - run = self.rundb.get_run(run_id_stc) - run['finished'] = True - self.rundb.buffer(run, True) + def test_10_create_run(self): + global run_id + # STC + run_id_stc = self.rundb.new_run( + "master", + "master", + 100000, + "10+0.01", + "book", + 10, + 1, + "", + "", + username="travis", + tests_repo="travis", + start_time=datetime.datetime.utcnow(), + ) + run = self.rundb.get_run(run_id_stc) + run["finished"] = True + self.rundb.buffer(run, True) - # LTC - run_id = self.rundb.new_run('master', 'master', 100000, '150+0.01', 'book', 10, 1, '', '', - username='travis', tests_repo='travis', - start_time=datetime.datetime.utcnow()) - print(' ') - print(run_id) - run = self.rundb.get_run(run_id) - print(run['tasks'][0]) - self.assertFalse(run['tasks'][0][u'active']) - run['tasks'][0][u'active'] = True - run['tasks'][0][u'worker_info'] = { - 'username': 'worker1', 'unique_key': 'travis', 'concurrency': 1} - run['cores'] = 1 + # LTC + run_id = self.rundb.new_run( + "master", + "master", + 100000, + "150+0.01", + "book", + 10, + 1, + "", + "", + username="travis", + tests_repo="travis", + start_time=datetime.datetime.utcnow(), + ) + print(" ") + print(run_id) + run = self.rundb.get_run(run_id) + print(run["tasks"][0]) + self.assertFalse(run["tasks"][0][u"active"]) + run["tasks"][0][u"active"] = True + run["tasks"][0][u"worker_info"] = { + "username": "worker1", + "unique_key": "travis", + "concurrency": 1, + } + run["cores"] = 1 - for run in self.rundb.get_unfinished_runs(): - if run['args']['username'] == 'travis': - print(run['args']) + for run in self.rundb.get_unfinished_runs(): + if run["args"]["username"] == "travis": + print(run["args"]) - def test_20_update_task(self): - run = self.rundb.update_task(run_id, 0, {'wins': 1, 'losses': 1, 'draws': self.rundb.chunk_size-3, - 'crashes': 0, 'time_losses': 0}, - 1000000, '?', '', 'worker2') - self.assertEqual(run, {'task_alive': False}) - run = self.rundb.update_task(run_id, 0, {'wins': 1, 'losses': 1, 'draws': self.rundb.chunk_size-4, - 'crashes': 0, 'time_losses': 0}, - 1000000, '?', '', 'worker1') - self.assertEqual(run, {'task_alive': True}) - run = self.rundb.update_task(run_id, 0, {'wins': 1, 'losses': 1, 'draws': self.rundb.chunk_size-2, - 'crashes': 0, 'time_losses': 0}, - 1000000, '?', '', 'worker1') - self.assertEqual(run, {'task_alive': False}) + def test_20_update_task(self): + run = self.rundb.update_task( + run_id, + 0, + { + "wins": 1, + "losses": 1, + "draws": self.rundb.chunk_size - 3, + "crashes": 0, + "time_losses": 0, + }, + 1000000, + "?", + "", + "worker2", + ) + self.assertEqual(run, {"task_alive": False}) + run = self.rundb.update_task( + run_id, + 0, + { + 
"wins": 1, + "losses": 1, + "draws": self.rundb.chunk_size - 4, + "crashes": 0, + "time_losses": 0, + }, + 1000000, + "?", + "", + "worker1", + ) + self.assertEqual(run, {"task_alive": True}) + run = self.rundb.update_task( + run_id, + 0, + { + "wins": 1, + "losses": 1, + "draws": self.rundb.chunk_size - 2, + "crashes": 0, + "time_losses": 0, + }, + 1000000, + "?", + "", + "worker1", + ) + self.assertEqual(run, {"task_alive": False}) - def test_30_finish(self): - print('run_id: {}'.format(run_id)) - run = self.rundb.get_run(run_id) - run['finished'] = True - self.rundb.buffer(run, True) + def test_30_finish(self): + print("run_id: {}".format(run_id)) + run = self.rundb.get_run(run_id) + run["finished"] = True + self.rundb.buffer(run, True) - def test_40_list_LTC(self): - finished_runs = self.rundb.get_finished_runs(limit=3, ltc_only=True)[0] - for run in finished_runs: - print(run['args']['tc']) + def test_40_list_LTC(self): + finished_runs = self.rundb.get_finished_runs(limit=3, ltc_only=True)[0] + for run in finished_runs: + print(run["args"]["tc"]) - def test_90_delete_runs(self): - for run in self.rundb.runs.find(): - if run['args']['username'] == 'travis' and not 'deleted' in run: - print('del ') - run['deleted'] = True - run['finished'] = True - for w in run['tasks']: - w['pending'] = False - self.rundb.buffer(run, True) + def test_90_delete_runs(self): + for run in self.rundb.runs.find(): + if run["args"]["username"] == "travis" and not "deleted" in run: + print("del ") + run["deleted"] = True + run["finished"] = True + for w in run["tasks"]: + w["pending"] = False + self.rundb.buffer(run, True) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/fishtest/tests/test_users.py b/fishtest/tests/test_users.py index 1cf4ca254..4c8029980 100644 --- a/fishtest/tests/test_users.py +++ b/fishtest/tests/test_users.py @@ -8,114 +8,119 @@ import util -class Create10UsersTest(unittest.TestCase): - def setUp(self): - self.rundb = util.get_rundb() - self.config = testing.setUp() - self.config.add_route('login', '/login') - self.config.add_route('signup', '/signup') - - def tearDown(self): - self.rundb.userdb.users.delete_many({ 'username': 'JoeUser' }) - self.rundb.userdb.user_cache.delete_many({ 'username': 'JoeUser' }) - self.rundb.stop() - testing.tearDown() - - def test_create_user(self): - request = testing.DummyRequest( - userdb=self.rundb.userdb, - method='POST', - remote_addr="127.0.0.1", - params={ - 'username': 'JoeUser', - 'password': 'secret', - 'password2': 'secret', - 'email': 'joe@user.net', - } - ) - response = signup(request) - self.assertTrue('The resource was found at', response) +class Create10UsersTest(unittest.TestCase): + def setUp(self): + self.rundb = util.get_rundb() + self.config = testing.setUp() + self.config.add_route("login", "/login") + self.config.add_route("signup", "/signup") + + def tearDown(self): + self.rundb.userdb.users.delete_many({"username": "JoeUser"}) + self.rundb.userdb.user_cache.delete_many({"username": "JoeUser"}) + self.rundb.stop() + testing.tearDown() + + def test_create_user(self): + request = testing.DummyRequest( + userdb=self.rundb.userdb, + method="POST", + remote_addr="127.0.0.1", + params={ + "username": "JoeUser", + "password": "secret", + "password2": "secret", + "email": "joe@user.net", + }, + ) + response = signup(request) + self.assertTrue("The resource was found at", response) class Create50LoginTest(unittest.TestCase): - - def setUp(self): - self.rundb = util.get_rundb() - 
self.rundb.userdb.create_user('JoeUser', 'secret', 'email@email.email') - self.config = testing.setUp() - self.config.add_route('login', '/login') - - def tearDown(self): - self.rundb.userdb.users.delete_many({ 'username': 'JoeUser' }) - self.rundb.userdb.user_cache.delete_many({ 'username': 'JoeUser' }) - self.rundb.stop() - testing.tearDown() - - def test_login(self): - request = testing.DummyRequest( - userdb=self.rundb.userdb, - method='POST', - params = { - 'username': 'JoeUser', - 'password': 'badsecret' - } - ) - response = login(request) - self.assertTrue('Invalid password' in request.session.pop_flash('error')) - - # Correct password, but still blocked from logging in - request.params['password'] = 'secret' - login(request) - self.assertTrue('Blocked' in request.session.pop_flash('error')) - - # Unblock, then user can log in successfully - user = self.rundb.userdb.get_user('JoeUser') - user['blocked'] = False - self.rundb.userdb.save_user(user) - response = login(request) - self.assertEqual(response.code, 302) - self.assertTrue('The resource was found at' in str(response)) + def setUp(self): + self.rundb = util.get_rundb() + self.rundb.userdb.create_user("JoeUser", "secret", "email@email.email") + self.config = testing.setUp() + self.config.add_route("login", "/login") + + def tearDown(self): + self.rundb.userdb.users.delete_many({"username": "JoeUser"}) + self.rundb.userdb.user_cache.delete_many({"username": "JoeUser"}) + self.rundb.stop() + testing.tearDown() + + def test_login(self): + request = testing.DummyRequest( + userdb=self.rundb.userdb, + method="POST", + params={"username": "JoeUser", "password": "badsecret"}, + ) + response = login(request) + self.assertTrue("Invalid password" in request.session.pop_flash("error")) + + # Correct password, but still blocked from logging in + request.params["password"] = "secret" + login(request) + self.assertTrue("Blocked" in request.session.pop_flash("error")) + + # Unblock, then user can log in successfully + user = self.rundb.userdb.get_user("JoeUser") + user["blocked"] = False + self.rundb.userdb.save_user(user) + response = login(request) + self.assertEqual(response.code, 302) + self.assertTrue("The resource was found at" in str(response)) class Create90APITest(unittest.TestCase): - def setUp(self): - self.rundb = util.get_rundb() - self.run_id = self.rundb.new_run('master', 'master', 100000, - '100+0.01', 'book', 10, 1, '', '', - username='travis', tests_repo='travis', - start_time=datetime.datetime.utcnow()) - self.rundb.userdb.user_cache.insert_one({ - 'username': 'JoeUser', - 'cpu_hours': 12345 - }) - self.config = testing.setUp() - self.config.add_route('api_stop_run', '/api/stop_run') - - def tearDown(self): - self.rundb.userdb.users.delete_many({'username': 'JoeUser'}) - self.rundb.userdb.user_cache.delete_many({'username': 'JoeUser'}) - self.rundb.stop() - testing.tearDown() - - def test_stop_run(self): - request = testing.DummyRequest( - rundb=self.rundb, - userdb=self.rundb.userdb, - actiondb=self.rundb.actiondb, - method='POST', - json_body={ - 'username': 'JoeUser', - 'password': 'secret', - 'run_id': self.run_id, - 'message': 'travis' - } - ) - response = ApiView(request).stop_run() - self.assertEqual(response, {}) - run = request.rundb.get_run(request.json_body['run_id']) - self.assertEqual(run['stop_reason'], 'travis') + def setUp(self): + self.rundb = util.get_rundb() + self.run_id = self.rundb.new_run( + "master", + "master", + 100000, + "100+0.01", + "book", + 10, + 1, + "", + "", + username="travis", + 
tests_repo="travis", + start_time=datetime.datetime.utcnow(), + ) + self.rundb.userdb.user_cache.insert_one( + {"username": "JoeUser", "cpu_hours": 12345} + ) + self.config = testing.setUp() + self.config.add_route("api_stop_run", "/api/stop_run") + + def tearDown(self): + self.rundb.userdb.users.delete_many({"username": "JoeUser"}) + self.rundb.userdb.user_cache.delete_many({"username": "JoeUser"}) + self.rundb.stop() + testing.tearDown() + + def test_stop_run(self): + request = testing.DummyRequest( + rundb=self.rundb, + userdb=self.rundb.userdb, + actiondb=self.rundb.actiondb, + method="POST", + json_body={ + "username": "JoeUser", + "password": "secret", + "run_id": self.run_id, + "message": "travis", + }, + ) + response = ApiView(request).stop_run() + self.assertEqual(response, {}) + run = request.rundb.get_run(request.json_body["run_id"]) + self.assertEqual(run["stop_reason"], "travis") if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/fishtest/tests/util.py b/fishtest/tests/util.py index 351b1c766..1a55ae4c2 100644 --- a/fishtest/tests/util.py +++ b/fishtest/tests/util.py @@ -1,11 +1,13 @@ from fishtest.rundb import RunDb + def get_rundb(): - return RunDb(db_name='fishtest_tests') - -def find_run(arg='username', value='travis'): - rundb = RunDb(db_name='fishtest_tests') - for run in rundb.get_unfinished_runs(): - if run['args'][arg] == value: - return run - return None + return RunDb(db_name="fishtest_tests") + + +def find_run(arg="username", value="travis"): + rundb = RunDb(db_name="fishtest_tests") + for run in rundb.get_unfinished_runs(): + if run["args"][arg] == value: + return run + return None diff --git a/fishtest/utils/clone_fish.py b/fishtest/utils/clone_fish.py index 3889b6e86..052a29ce1 100644 --- a/fishtest/utils/clone_fish.py +++ b/fishtest/utils/clone_fish.py @@ -6,55 +6,59 @@ from pymongo import MongoClient, ASCENDING, DESCENDING from bson.binary import Binary -#fish_host = 'http://localhost:6543' -fish_host = 'http://94.198.98.239' # 'http://tests.stockfishchess.org' +# fish_host = 'http://localhost:6543' +fish_host = "http://94.198.98.239" # 'http://tests.stockfishchess.org' -conn = MongoClient('localhost') +conn = MongoClient("localhost") -#conn.drop_database('fish_clone') +# conn.drop_database('fish_clone') -db = conn['fish_clone'] +db = conn["fish_clone"] -pgndb = db['pgns'] -runs = db['runs'] +pgndb = db["pgns"] +runs = db["runs"] + +pgndb.ensure_index([("run_id", ASCENDING)]) -pgndb.ensure_index([('run_id', ASCENDING)]) def main(): - """clone a fishtest database with PGNs and runs with the REST API""" - - skip = 0 - count = 0 - in_sync = False - loaded = {} - while True: - pgn_list = requests.get(fish_host + '/api/pgn_100/' + str(skip)).json() - for pgn_file in pgn_list: - print(pgn_file) - if pgndb.find_one({'run_id': pgn_file}): - print('Already copied: %s' % (pgn_file)) - if not pgn_file in loaded: - in_sync = True - break - else: - run_id = pgn_file.split('-')[0] - if not runs.find_one({'_id': run_id}): - print('New run: ' + run_id) - run = requests.get(fish_host + '/api/get_run/' + run_id).json() - runs.insert(run) - pgn = requests.get(fish_host + '/api/pgn/' + pgn_file) - pgndb.insert(dict(pgn_bz2=Binary(bz2.compress(pgn.content)), run_id= pgn_file)) - loaded[pgn_file] = True - count += 1 - skip += len(pgn_list) - if in_sync or len(pgn_list) < 100: - break - - print('Copied: %6d PGN files (~ %8d games)' % (count, 250 * count)) - count = pgndb.count() - print('Database:%6d PGN files (~ %8d games)' % (count, 250 * count)) - count = 
runs.count() - print('Database:%6d runs' % (count)) - -if __name__ == '__main__': - main() + """clone a fishtest database with PGNs and runs with the REST API""" + + skip = 0 + count = 0 + in_sync = False + loaded = {} + while True: + pgn_list = requests.get(fish_host + "/api/pgn_100/" + str(skip)).json() + for pgn_file in pgn_list: + print(pgn_file) + if pgndb.find_one({"run_id": pgn_file}): + print("Already copied: %s" % (pgn_file)) + if not pgn_file in loaded: + in_sync = True + break + else: + run_id = pgn_file.split("-")[0] + if not runs.find_one({"_id": run_id}): + print("New run: " + run_id) + run = requests.get(fish_host + "/api/get_run/" + run_id).json() + runs.insert(run) + pgn = requests.get(fish_host + "/api/pgn/" + pgn_file) + pgndb.insert( + dict(pgn_bz2=Binary(bz2.compress(pgn.content)), run_id=pgn_file) + ) + loaded[pgn_file] = True + count += 1 + skip += len(pgn_list) + if in_sync or len(pgn_list) < 100: + break + + print("Copied: %6d PGN files (~ %8d games)" % (count, 250 * count)) + count = pgndb.count() + print("Database:%6d PGN files (~ %8d games)" % (count, 250 * count)) + count = runs.count() + print("Database:%6d runs" % (count)) + + +if __name__ == "__main__": + main() diff --git a/fishtest/utils/compact_actions.py b/fishtest/utils/compact_actions.py index 49eb0cc1e..a0e002676 100644 --- a/fishtest/utils/compact_actions.py +++ b/fishtest/utils/compact_actions.py @@ -9,30 +9,33 @@ from fishtest.actiondb import ActionDb conn = MongoClient() -db = conn['fishtest_new'] +db = conn["fishtest_new"] actiondb = ActionDb(db) + def compact_actions(): - for a in actiondb.actions.find(): - update = False - if 'tasks' in a['data']: - del a['data']['tasks'] - print(a['data']['_id']) - update = True - if 'before' in a['data']: - del a['data']['before']['tasks'] - print('before') - update = True - if 'after' in a['data']: - del a['data']['after']['tasks'] - print('after') - update = True - - if update: - actiondb.actions.replace_one({ '_id': a['_id'] }, a) + for a in actiondb.actions.find(): + update = False + if "tasks" in a["data"]: + del a["data"]["tasks"] + print(a["data"]["_id"]) + update = True + if "before" in a["data"]: + del a["data"]["before"]["tasks"] + print("before") + update = True + if "after" in a["data"]: + del a["data"]["after"]["tasks"] + print("after") + update = True + + if update: + actiondb.actions.replace_one({"_id": a["_id"]}, a) + def main(): - compact_actions() + compact_actions() + -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/fishtest/utils/create_indexes.py b/fishtest/utils/create_indexes.py index 239ad86ac..90cbe536c 100644 --- a/fishtest/utils/create_indexes.py +++ b/fishtest/utils/create_indexes.py @@ -9,119 +9,127 @@ from pymongo import MongoClient, ASCENDING, DESCENDING -db_name = 'fishtest_new' +db_name = "fishtest_new" # MongoDB server is assumed to be on the same machine, if not user should use # ssh with port forwarding to access the remote host. 
-conn = MongoClient(os.getenv('FISHTEST_HOST') or 'localhost') +conn = MongoClient(os.getenv("FISHTEST_HOST") or "localhost") db = conn[db_name] def create_runs_indexes(): - print('Creating indexes on runs collection') - db['runs'].create_index( - [('last_updated', ASCENDING)], - name='unfinished_runs', - partialFilterExpression={ 'finished': False } - ) - db['runs'].create_index( - [('last_updated', DESCENDING)], - name='finished_runs', - partialFilterExpression={ 'finished': True } - ) - db['runs'].create_index( - [('last_updated', DESCENDING), ('is_green', DESCENDING)], - name='finished_green_runs', - partialFilterExpression={ 'finished': True, 'is_green': True } - ) - db['runs'].create_index( - [('last_updated', DESCENDING), ('is_yellow', DESCENDING)], - name='finished_yellow_runs', - partialFilterExpression={ 'finished': True, 'is_yellow': True } - ) - db['runs'].create_index( - [('last_updated', DESCENDING), ('tc_base', DESCENDING)], - name='finished_ltc_runs', - partialFilterExpression={ 'finished': True, 'tc_base': { '$gte': 40 } } - ) - db['runs'].create_index( - [('args.username', DESCENDING), ('last_updated', DESCENDING)], - name='user_runs' - ) + print("Creating indexes on runs collection") + db["runs"].create_index( + [("last_updated", ASCENDING)], + name="unfinished_runs", + partialFilterExpression={"finished": False}, + ) + db["runs"].create_index( + [("last_updated", DESCENDING)], + name="finished_runs", + partialFilterExpression={"finished": True}, + ) + db["runs"].create_index( + [("last_updated", DESCENDING), ("is_green", DESCENDING)], + name="finished_green_runs", + partialFilterExpression={"finished": True, "is_green": True}, + ) + db["runs"].create_index( + [("last_updated", DESCENDING), ("is_yellow", DESCENDING)], + name="finished_yellow_runs", + partialFilterExpression={"finished": True, "is_yellow": True}, + ) + db["runs"].create_index( + [("last_updated", DESCENDING), ("tc_base", DESCENDING)], + name="finished_ltc_runs", + partialFilterExpression={"finished": True, "tc_base": {"$gte": 40}}, + ) + db["runs"].create_index( + [("args.username", DESCENDING), ("last_updated", DESCENDING)], name="user_runs" + ) + def create_pgns_indexes(): - print('Creating indexes on pgns collection') - db['pgns'].create_index([('run_id', DESCENDING)]) + print("Creating indexes on pgns collection") + db["pgns"].create_index([("run_id", DESCENDING)]) + def create_nns_indexes(): - print('Creating indexes on nns collection') - db['nns'].create_index([('name', DESCENDING)]) + print("Creating indexes on nns collection") + db["nns"].create_index([("name", DESCENDING)]) + def create_users_indexes(): - db['users'].create_index('username', unique=True) + db["users"].create_index("username", unique=True) + def create_actions_indexes(): - db['actions'].create_index([ - ('username', ASCENDING), - ('_id', DESCENDING), - ]) - db['actions'].create_index([ - ('action', ASCENDING), - ('_id', DESCENDING), - ]) + db["actions"].create_index([("username", ASCENDING), ("_id", DESCENDING)]) + db["actions"].create_index([("action", ASCENDING), ("_id", DESCENDING)]) + def create_flag_cache_indexes(): - db['flag_cache'].create_index('ip') - db['flag_cache'].create_index('geoip_checked_at', expireAfterSeconds=60 * 60 * 24 * 7) + db["flag_cache"].create_index("ip") + db["flag_cache"].create_index( + "geoip_checked_at", expireAfterSeconds=60 * 60 * 24 * 7 + ) + def print_current_indexes(): - for collection_name in db.collection_names(): - c = db[collection_name] - print('Current indexes on ' + collection_name + 
':') - pprint.pprint(c.index_information(), stream=None, indent=2, width=110, depth=None) - print('') + for collection_name in db.collection_names(): + c = db[collection_name] + print("Current indexes on " + collection_name + ":") + pprint.pprint( + c.index_information(), stream=None, indent=2, width=110, depth=None + ) + print("") + def drop_indexes(collection_name): - # Drop all indexes on collection except _id_ - print('\nDropping indexes on {}'.format(collection_name)) - collection = db[collection_name] - index_keys = list(collection.index_information().keys()) - print('Current indexes: {}'.format(index_keys)) - for idx in index_keys: - if idx != '_id_': - print('Dropping ' + collection_name + ' index ' + idx + ' ...') - collection.drop_index(idx) - - -if __name__ == '__main__': - # Takes a list of collection names as arguments. - # For each collection name, this script drops indexes and re-creates them. - # With no argument, indexes are printed, but no indexes are re-created. - collection_names = sys.argv[1:] - if collection_names: - print('Re-creating indexes...') - for collection_name in collection_names: - if collection_name == 'users': - drop_indexes('users') - create_users_indexes() - elif collection_name == 'actions': - drop_indexes('actions') - create_actions_indexes() - elif collection_name == 'runs': - drop_indexes('runs') - create_runs_indexes() - elif collection_name == 'pgns': - drop_indexes('pgns') - create_pgns_indexes() - elif collection_name == 'nns': - drop_indexes('nns') - create_nns_indexes() - elif collection_name == 'flag_cache': - drop_indexes('flag_cache') - create_flag_cache_indexes() - print('Finished creating indexes!\n') - print_current_indexes() - if not collection_names: - print('Collections in {}: {}'.format(db_name, db.collection_names())) - print('Give a list of collection names as arguments to re-create indexes. For example:\n') - print(' python3 create_indexes.py users runs - drops and creates indexes for runs and users\n') + # Drop all indexes on collection except _id_ + print("\nDropping indexes on {}".format(collection_name)) + collection = db[collection_name] + index_keys = list(collection.index_information().keys()) + print("Current indexes: {}".format(index_keys)) + for idx in index_keys: + if idx != "_id_": + print("Dropping " + collection_name + " index " + idx + " ...") + collection.drop_index(idx) + + +if __name__ == "__main__": + # Takes a list of collection names as arguments. + # For each collection name, this script drops indexes and re-creates them. + # With no argument, indexes are printed, but no indexes are re-created. + collection_names = sys.argv[1:] + if collection_names: + print("Re-creating indexes...") + for collection_name in collection_names: + if collection_name == "users": + drop_indexes("users") + create_users_indexes() + elif collection_name == "actions": + drop_indexes("actions") + create_actions_indexes() + elif collection_name == "runs": + drop_indexes("runs") + create_runs_indexes() + elif collection_name == "pgns": + drop_indexes("pgns") + create_pgns_indexes() + elif collection_name == "nns": + drop_indexes("nns") + create_nns_indexes() + elif collection_name == "flag_cache": + drop_indexes("flag_cache") + create_flag_cache_indexes() + print("Finished creating indexes!\n") + print_current_indexes() + if not collection_names: + print("Collections in {}: {}".format(db_name, db.collection_names())) + print( + "Give a list of collection names as arguments to re-create indexes. 
For example:\n" + ) + print( + " python3 create_indexes.py users runs - drops and creates indexes for runs and users\n" + ) diff --git a/fishtest/utils/create_pgndb.py b/fishtest/utils/create_pgndb.py index 556e66b43..da6c8de0f 100644 --- a/fishtest/utils/create_pgndb.py +++ b/fishtest/utils/create_pgndb.py @@ -2,10 +2,10 @@ from pymongo import MongoClient, ASCENDING, DESCENDING -conn = MongoClient('localhost') +conn = MongoClient("localhost") -db = conn['fishtest_new'] +db = conn["fishtest_new"] -db.drop_collection('pgns') +db.drop_collection("pgns") -db.create_collection('pgns', capped=True, size=50000) +db.create_collection("pgns", capped=True, size=50000) diff --git a/fishtest/utils/current.py b/fishtest/utils/current.py index 6111a7407..e3763e976 100644 --- a/fishtest/utils/current.py +++ b/fishtest/utils/current.py @@ -15,17 +15,19 @@ import sys from pymongo import MongoClient, ASCENDING, DESCENDING -db_name='fishtest_new' +db_name = "fishtest_new" # MongoDB server is assumed to be on the same machine, if not user should use # ssh with port forwarding to access the remote host. -conn = MongoClient(os.getenv('FISHTEST_HOST') or 'localhost') +conn = MongoClient(os.getenv("FISHTEST_HOST") or "localhost") db = conn[db_name] -runs = db['runs'] +runs = db["runs"] + def printout(s): - print(s) - sys.stdout.flush() + print(s) + sys.stdout.flush() + # display current list of indexes printout("Current Indexes:") @@ -39,7 +41,11 @@ def printout(s): printout("\nCurrent operations:") t = 0.3 if len(sys.argv) > 1: - t = float(sys.argv[1]) -pprint.pprint(db.current_op({'secs_running': {'$gte': t}, 'query': {'$ne': {}}}), - stream=None, indent=1, width=80, depth=None) - + t = float(sys.argv[1]) +pprint.pprint( + db.current_op({"secs_running": {"$gte": t}, "query": {"$ne": {}}}), + stream=None, + indent=1, + width=80, + depth=None, +) diff --git a/fishtest/utils/delta_update_users.py b/fishtest/utils/delta_update_users.py index 2437a4854..ba520b0ea 100644 --- a/fishtest/utils/delta_update_users.py +++ b/fishtest/utils/delta_update_users.py @@ -7,7 +7,7 @@ from pymongo import DESCENDING # For tasks -sys.path.append(os.path.expanduser('~/fishtest/fishtest')) +sys.path.append(os.path.expanduser("~/fishtest/fishtest")) from fishtest.rundb import RunDb from fishtest.util import estimate_game_duration, delta_date @@ -15,185 +15,205 @@ new_deltas = {} skip = False + def process_run(run, info, deltas=None): - global skip - if deltas and (skip or str(run['_id']) in deltas): - skip = True - return - if deltas != None and str(run['_id']) in new_deltas: - print('Warning: skipping repeated run!') - return - if 'username' in run['args']: - username = run['args']['username'] - if username not in info: - print('not in info: ', username) - return - else: - info[username]['tests'] += 1 - - tc = estimate_game_duration(run['args']['tc']) - for task in run['tasks']: - if 'worker_info' not in task: - continue - username = task['worker_info'].get('username', None) - if username == None: - continue - if username not in info: - print('not in info: ', username) - continue - - if 'stats' in task: - stats = task['stats'] - num_games = stats['wins'] + stats['losses'] + stats['draws'] - else: - num_games = 0 + global skip + if deltas and (skip or str(run["_id"]) in deltas): + skip = True + return + if deltas != None and str(run["_id"]) in new_deltas: + print("Warning: skipping repeated run!") + return + if "username" in run["args"]: + username = run["args"]["username"] + if username not in info: + print("not in info: ", username) 
+ return + else: + info[username]["tests"] += 1 + + tc = estimate_game_duration(run["args"]["tc"]) + for task in run["tasks"]: + if "worker_info" not in task: + continue + username = task["worker_info"].get("username", None) + if username == None: + continue + if username not in info: + print("not in info: ", username) + continue + + if "stats" in task: + stats = task["stats"] + num_games = stats["wins"] + stats["losses"] + stats["draws"] + else: + num_games = 0 - try: - info[username]['last_updated'] = max(task['last_updated'], info[username]['last_updated']) - info[username]['task_last_updated'] = max(task['last_updated'], info[username]['last_updated']) - except: - info[username]['last_updated'] = task['last_updated'] + try: + info[username]["last_updated"] = max( + task["last_updated"], info[username]["last_updated"] + ) + info[username]["task_last_updated"] = max( + task["last_updated"], info[username]["last_updated"] + ) + except: + info[username]["last_updated"] = task["last_updated"] + + info[username]["cpu_hours"] += float( + num_games * int(run["args"].get("threads", 1)) * tc / (60 * 60) + ) + info[username]["games"] += num_games + if deltas != None: + new_deltas.update({str(run["_id"]): None}) - info[username]['cpu_hours'] += float(num_games * int(run['args'].get('threads', 1)) * tc / (60 * 60)) - info[username]['games'] += num_games - if deltas != None: - new_deltas.update({ str(run['_id']) : None}) def build_users(machines, info): - for machine in machines: - games_per_hour = (machine['nps'] / 1600000.0) * (3600.0 / estimate_game_duration(machine['run']['args']['tc'])) * (int(machine['concurrency']) // machine['run']['args'].get('threads', 1)) - info[machine['username']]['games_per_hour'] += games_per_hour - - users = [] - for u in info.keys(): - user = info[u] - try: - if isinstance(user['last_updated'], str): - user['last_updated'] = delta_date(user['task_last_updated']) - else: - user['last_updated'] = delta_date(user['last_updated']) - except: - pass - users.append(user) - - users = [u for u in users if u['games'] > 0 or u['tests'] > 0] - return users + for machine in machines: + games_per_hour = ( + (machine["nps"] / 1600000.0) + * (3600.0 / estimate_game_duration(machine["run"]["args"]["tc"])) + * (int(machine["concurrency"]) // machine["run"]["args"].get("threads", 1)) + ) + info[machine["username"]]["games_per_hour"] += games_per_hour + + users = [] + for u in info.keys(): + user = info[u] + try: + if isinstance(user["last_updated"], str): + user["last_updated"] = delta_date(user["task_last_updated"]) + else: + user["last_updated"] = delta_date(user["last_updated"]) + except: + pass + users.append(user) + + users = [u for u in users if u["games"] > 0 or u["tests"] > 0] + return users + def update_users(): - rundb = RunDb() - - deltas = {} - info = {} - top_month = {} - - clear_stats = True - if len(sys.argv) > 1: - print('scan all') - else: - deltas = rundb.deltas.find_one() - if deltas: - clear_stats = False - else: - deltas = {} - - for u in rundb.userdb.get_users(): - username = u['username'] - top_month[username] = {'username': username, - 'cpu_hours': 0, - 'games': 0, - 'tests': 0, - 'tests_repo': u.get('tests_repo', ''), - 'last_updated': datetime.min, - 'games_per_hour': 0.0,} - if clear_stats: - info[username] = top_month[username].copy() + rundb = RunDb() + + deltas = {} + info = {} + top_month = {} + + clear_stats = True + if len(sys.argv) > 1: + print("scan all") else: - info[username] = rundb.userdb.user_cache.find_one({'username': username}) - if not 
info[username]: - info[username] = top_month[username].copy() - else: - info[username]['games_per_hour'] = 0.0 - - for run in rundb.get_unfinished_runs(): - try: - process_run(run, top_month) - except: - print("Exception on run: ", run) - - # Step through these in small batches (step size 100) to save RAM - step_size = 100 - - now = datetime.utcnow() - more_days = True - last_updated = None - while more_days: - q = { 'finished': True } - if last_updated: - q['last_updated'] = { '$lt': last_updated } - runs = list(rundb.runs.find(q, sort=[('last_updated', DESCENDING)], limit=step_size)) - if len(runs) == 0: - break - for run in runs: - try: - process_run(run, info, deltas) - except: - print("Exception on run: ", run['_id']) - if (now - run['start_time']).days < 30: + deltas = rundb.deltas.find_one() + if deltas: + clear_stats = False + else: + deltas = {} + + for u in rundb.userdb.get_users(): + username = u["username"] + top_month[username] = { + "username": username, + "cpu_hours": 0, + "games": 0, + "tests": 0, + "tests_repo": u.get("tests_repo", ""), + "last_updated": datetime.min, + "games_per_hour": 0.0, + } + if clear_stats: + info[username] = top_month[username].copy() + else: + info[username] = rundb.userdb.user_cache.find_one({"username": username}) + if not info[username]: + info[username] = top_month[username].copy() + else: + info[username]["games_per_hour"] = 0.0 + + for run in rundb.get_unfinished_runs(): try: - process_run(run, top_month) + process_run(run, top_month) except: - print("Exception on run: ", run['_id']) - elif not clear_stats: - more_days = False - last_updated = runs[-1]['last_updated'] - - if new_deltas: - new_deltas.update(deltas) - rundb.deltas.remove() - rundb.deltas.save(new_deltas) - - machines = rundb.get_machines() - - users = build_users(machines, info) - rundb.userdb.user_cache.remove() - rundb.userdb.user_cache.insert(users) - rundb.userdb.user_cache.create_index('username', unique=True) - - rundb.userdb.top_month.remove() - rundb.userdb.top_month.insert(build_users(machines, top_month)) - - # Delete users that have never been active and old admins group - idle = {} - for u in rundb.userdb.get_users(): - update = False - while 'group:admins' in u['groups']: - u['groups'].remove('group:admins') - update = True - if update: - rundb.userdb.users.save(u) - if not 'registration_time' in u \ - or u['registration_time'] < datetime.utcnow() - timedelta(days=28): - idle[u['username']] = u - for u in rundb.userdb.user_cache.find(): - if u['username'] in idle: - del idle[u['username']] - for u in idle.values(): - # A safe guard against deleting long time users - if not 'registration_time' in u \ - or u['registration_time'] < datetime.utcnow() - timedelta(days=38): - print('Warning: Found old user to delete: ' + str(u['_id'])) - else: - print('Delete: ' + str(u['_id'])) - rundb.userdb.users.remove({'_id': u['_id']}) - - print('Successfully updated %d users' % (len(users))) - - # record this update run - rundb.actiondb.update_stats() + print("Exception on run: ", run) + + # Step through these in small batches (step size 100) to save RAM + step_size = 100 + + now = datetime.utcnow() + more_days = True + last_updated = None + while more_days: + q = {"finished": True} + if last_updated: + q["last_updated"] = {"$lt": last_updated} + runs = list( + rundb.runs.find(q, sort=[("last_updated", DESCENDING)], limit=step_size) + ) + if len(runs) == 0: + break + for run in runs: + try: + process_run(run, info, deltas) + except: + print("Exception on run: ", run["_id"]) + if 
(now - run["start_time"]).days < 30: + try: + process_run(run, top_month) + except: + print("Exception on run: ", run["_id"]) + elif not clear_stats: + more_days = False + last_updated = runs[-1]["last_updated"] + + if new_deltas: + new_deltas.update(deltas) + rundb.deltas.remove() + rundb.deltas.save(new_deltas) + + machines = rundb.get_machines() + + users = build_users(machines, info) + rundb.userdb.user_cache.remove() + rundb.userdb.user_cache.insert(users) + rundb.userdb.user_cache.create_index("username", unique=True) + + rundb.userdb.top_month.remove() + rundb.userdb.top_month.insert(build_users(machines, top_month)) + + # Delete users that have never been active and old admins group + idle = {} + for u in rundb.userdb.get_users(): + update = False + while "group:admins" in u["groups"]: + u["groups"].remove("group:admins") + update = True + if update: + rundb.userdb.users.save(u) + if not "registration_time" in u or u[ + "registration_time" + ] < datetime.utcnow() - timedelta(days=28): + idle[u["username"]] = u + for u in rundb.userdb.user_cache.find(): + if u["username"] in idle: + del idle[u["username"]] + for u in idle.values(): + # A safe guard against deleting long time users + if not "registration_time" in u or u[ + "registration_time" + ] < datetime.utcnow() - timedelta(days=38): + print("Warning: Found old user to delete: " + str(u["_id"])) + else: + print("Delete: " + str(u["_id"])) + rundb.userdb.users.remove({"_id": u["_id"]}) + + print("Successfully updated %d users" % (len(users))) + + # record this update run + rundb.actiondb.update_stats() def main(): - update_users() + update_users() + -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/fishtest/utils/index_pending.py b/fishtest/utils/index_pending.py index d79a48098..106e22ce5 100644 --- a/fishtest/utils/index_pending.py +++ b/fishtest/utils/index_pending.py @@ -15,25 +15,27 @@ import pprint from pymongo import MongoClient, ASCENDING, DESCENDING -db_name='fishtest_new' +db_name = "fishtest_new" # MongoDB server is assumed to be on the same machine, if not user should use # ssh with port forwarding to access the remote host. 
-conn = MongoClient(os.getenv('FISHTEST_HOST') or 'localhost') +conn = MongoClient(os.getenv("FISHTEST_HOST") or "localhost") db = conn[db_name] -runs = db['runs'] -pgns = db['pgns'] +runs = db["runs"] +pgns = db["pgns"] + def printout(s): - print(s) - sys.stdout.flush() + print(s) + sys.stdout.flush() + # create indexes: -#printout("Creating index ...") -#runs.create_index([('tasks.pending', ASCENDING)]) +# printout("Creating index ...") +# runs.create_index([('tasks.pending', ASCENDING)]) printout("Creating pgn index ...") -pgns.ensure_index([('run_id', ASCENDING)]) +pgns.ensure_index([("run_id", ASCENDING)]) # IF INDEX NEEDS TO BE DROPPED, COMMENT OUT ABOVE 2 LINES, AND UNCOMMENT NEXT 2: # printout("\nDropping index ...") diff --git a/fishtest/utils/purge_pgn.py b/fishtest/utils/purge_pgn.py index 33f8399d1..27b6620c5 100644 --- a/fishtest/utils/purge_pgn.py +++ b/fishtest/utils/purge_pgn.py @@ -12,49 +12,57 @@ rundb = RunDb() + def purge_pgn(days): - """Purge old PGNs except LTC (>= 20s) runs""" - - deleted_runs = 0 - deleted_tasks = 0 - saved_runs = 0 - saved_tasks = 0 - now = datetime.utcnow() - - run_count = 0 - for run in rundb.runs.find({'finished': True, 'deleted': {'$exists': False}}, - sort=[('last_updated', DESCENDING)]): - - if (now - run['start_time']).days > 30: - break - - run_count += 1 - if run_count % 10 == 0: - print('Run: %05d' % (run_count), end='\r') - - skip = False - if (re.match('^([2-9][0-9])|(1[0-9][0-9])', run['args']['tc']) \ - and run['last_updated'] > datetime.utcnow() - timedelta(days=5*days)) \ - or run['last_updated'] > datetime.utcnow() - timedelta(days=days): - saved_runs += 1 - skip = True - else: - deleted_runs += 1 - - for idx, task in enumerate(run['tasks']): - key = str(run['_id']) + '-' + str(idx) - for pgn in rundb.pgndb.find({'run_id': key}): # We can have multiple PGNs per task - if skip: - saved_tasks += 1 + """Purge old PGNs except LTC (>= 20s) runs""" + + deleted_runs = 0 + deleted_tasks = 0 + saved_runs = 0 + saved_tasks = 0 + now = datetime.utcnow() + + run_count = 0 + for run in rundb.runs.find( + {"finished": True, "deleted": {"$exists": False}}, + sort=[("last_updated", DESCENDING)], + ): + + if (now - run["start_time"]).days > 30: + break + + run_count += 1 + if run_count % 10 == 0: + print("Run: %05d" % (run_count), end="\r") + + skip = False + if ( + re.match("^([2-9][0-9])|(1[0-9][0-9])", run["args"]["tc"]) + and run["last_updated"] > datetime.utcnow() - timedelta(days=5 * days) + ) or run["last_updated"] > datetime.utcnow() - timedelta(days=days): + saved_runs += 1 + skip = True else: - rundb.pgndb.remove({'_id': pgn['_id']}) - deleted_tasks += 1 + deleted_runs += 1 + + for idx, task in enumerate(run["tasks"]): + key = str(run["_id"]) + "-" + str(idx) + for pgn in rundb.pgndb.find( + {"run_id": key} + ): # We can have multiple PGNs per task + if skip: + saved_tasks += 1 + else: + rundb.pgndb.remove({"_id": pgn["_id"]}) + deleted_tasks += 1 + + print("PGN runs/tasks saved: %5d/%7d" % (saved_runs, saved_tasks)) + print("PGN runs/tasks purged: %5d/%7d" % (deleted_runs, deleted_tasks)) - print('PGN runs/tasks saved: %5d/%7d' % (saved_runs, saved_tasks)) - print('PGN runs/tasks purged: %5d/%7d' % (deleted_runs, deleted_tasks)) def main(): - purge_pgn(days=2) + purge_pgn(days=2) + -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/fishtest/utils/scavenge.py b/fishtest/utils/scavenge.py index ce7fbe211..d2076d024 100644 --- a/fishtest/utils/scavenge.py +++ b/fishtest/utils/scavenge.py @@ -7,25 +7,30 @@ 
from datetime import datetime, timedelta # For tasks -sys.path.append(os.path.expanduser('~/fishtest/fishtest')) +sys.path.append(os.path.expanduser("~/fishtest/fishtest")) from fishtest.rundb import RunDb rundb = RunDb() + def scavenge_tasks(scavenge=True, minutes=60): - """Check for tasks that have not been updated recently""" - for run in rundb.runs.find({'tasks': {'$elemMatch': {'active': True}}}): - changed = False - for idx, task in enumerate(run['tasks']): - if task['active'] and task['last_updated'] < datetime.utcnow() - timedelta(minutes=minutes): - print('Scavenging', task) - task['active'] = False - changed = True - if changed and scavenge: - rundb.runs.save(run) + """Check for tasks that have not been updated recently""" + for run in rundb.runs.find({"tasks": {"$elemMatch": {"active": True}}}): + changed = False + for idx, task in enumerate(run["tasks"]): + if task["active"] and task["last_updated"] < datetime.utcnow() - timedelta( + minutes=minutes + ): + print("Scavenging", task) + task["active"] = False + changed = True + if changed and scavenge: + rundb.runs.save(run) + def main(): - scavenge_tasks(scavenge=True) + scavenge_tasks(scavenge=True) + -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/fishtest/utils/test_queries.py b/fishtest/utils/test_queries.py index 76da7951c..519e0486b 100644 --- a/fishtest/utils/test_queries.py +++ b/fishtest/utils/test_queries.py @@ -1,4 +1,3 @@ - # test_queries.py - run some sample queries to check db speed # @@ -10,24 +9,27 @@ import time from pymongo import MongoClient, ASCENDING, DESCENDING -sys.path.append(os.path.expanduser('~/fishtest/fishtest')) +sys.path.append(os.path.expanduser("~/fishtest/fishtest")) from fishtest.rundb import RunDb -db_name='fishtest_new' +db_name = "fishtest_new" rundb = RunDb() # MongoDB server is assumed to be on the same machine, if not user should use # ssh with port forwarding to access the remote host. -conn = MongoClient(os.getenv('FISHTEST_HOST') or 'localhost') +conn = MongoClient(os.getenv("FISHTEST_HOST") or "localhost") db = conn[db_name] -runs = db['runs'] -pgns = db['pgns'] +runs = db["runs"] +pgns = db["pgns"] def qlen(c): - if (c): return len(list(c)) - else : return 0 + if c: + return len(list(c)) + else: + return 0 + # Extra conditions that might be applied to finished_runs: # q['args.username'] = username @@ -42,55 +44,61 @@ def qlen(c): c = rundb.get_unfinished_runs() end = time.time() -print("{} rows {:1.4f}".format(qlen(c), end-start) + "s\nFetching machines ...") +print("{} rows {:1.4f}".format(qlen(c), end - start) + "s\nFetching machines ...") start = time.time() c = rundb.get_machines() end = time.time() -print("{} rows {:1.4f}".format(qlen(c), end-start) + "s\nFetching finished runs ...") +print("{} rows {:1.4f}".format(qlen(c), end - start) + "s\nFetching finished runs ...") start = time.time() -c, n = rundb.get_finished_runs(skip=0, limit=50, username='', - success_only=False, ltc_only=False) +c, n = rundb.get_finished_runs( + skip=0, limit=50, username="", success_only=False, ltc_only=False +) end = time.time() -print("{} rows {:1.4f}".format(qlen(c), end-start) + "s\nFetching finished runs (vdv) ...") +print( + "{} rows {:1.4f}".format(qlen(c), end - start) + + "s\nFetching finished runs (vdv) ..." 
+) start = time.time() -c, n = rundb.get_finished_runs(skip=0, limit=50, username='vdv', - success_only=False, ltc_only=False) +c, n = rundb.get_finished_runs( + skip=0, limit=50, username="vdv", success_only=False, ltc_only=False +) end = time.time() -print("{} rows {:1.4f}".format(qlen(c), end-start) + "s\nRequesting pgn ...") -if (n == 0): - c.append({'_id':'abc'}) +print("{} rows {:1.4f}".format(qlen(c), end - start) + "s\nRequesting pgn ...") +if n == 0: + c.append({"_id": "abc"}) start = time.time() -c = rundb.get_pgn(str(c[0]['_id']) + ".pgn") +c = rundb.get_pgn(str(c[0]["_id"]) + ".pgn") end = time.time() # Tests: Explain some queries - should show which indexes are being used -print("{} rows {:1.4f}".format(qlen(c), end-start) + "s") +print("{} rows {:1.4f}".format(qlen(c), end - start) + "s") print("\n\nExplain queries") print("\nFetching unfinished runs xp ...") start = time.time() -c = runs.find({'finished': False}, sort=[('last_updated', DESCENDING), ('start_time', DESCENDING)]).explain() +c = runs.find( + {"finished": False}, sort=[("last_updated", DESCENDING), ("start_time", DESCENDING)] +).explain() print(pprint.pformat(c, indent=3, width=110)) end = time.time() -print("{} rows {:1.4f}".format(qlen(c), end-start) + "s") +print("{} rows {:1.4f}".format(qlen(c), end - start) + "s") print("\nFetching machines xp ...") start = time.time() -c = runs.find({'finished': False, 'tasks': {'$elemMatch': {'active': True}}}).explain() +c = runs.find({"finished": False, "tasks": {"$elemMatch": {"active": True}}}).explain() print(pprint.pformat(c, indent=3, width=110)) end = time.time() -print("{} rows {:1.4f}".format(qlen(c), end-start) + "s") +print("{} rows {:1.4f}".format(qlen(c), end - start) + "s") print("\nFetching finished runs xp ...") start = time.time() -q = {'finished': True} -c = runs.find(q, skip=0, limit=50, sort=[('last_updated', DESCENDING)]).explain() +q = {"finished": True} +c = runs.find(q, skip=0, limit=50, sort=[("last_updated", DESCENDING)]).explain() print(pprint.pformat(c, indent=3, width=110)) end = time.time() -print("{} rows {:1.4f}".format(qlen(c), end-start) + "s") - +print("{} rows {:1.4f}".format(qlen(c), end - start) + "s") diff --git a/worker/games.py b/worker/games.py index e813a6402..49ad6213d 100644 --- a/worker/games.py +++ b/worker/games.py @@ -23,758 +23,1016 @@ from zipfile import ZipFile try: - from Queue import Queue, Empty + from Queue import Queue, Empty except ImportError: - from queue import Queue, Empty # python 3.x + from queue import Queue, Empty # python 3.x -IS_WINDOWS = 'windows' in platform.system().lower() +IS_WINDOWS = "windows" in platform.system().lower() + +ARCH = "?" -ARCH = '?' 
def is_windows_64bit():
-  if 'PROCESSOR_ARCHITEW6432' in os.environ:
-    return True
-  return os.environ['PROCESSOR_ARCHITECTURE'].endswith('64')
+    if "PROCESSOR_ARCHITEW6432" in os.environ:
+        return True
+    return os.environ["PROCESSOR_ARCHITECTURE"].endswith("64")
+
 
 def is_64bit():
-  if IS_WINDOWS:
-    return is_windows_64bit()
-  return '64' in platform.architecture()[0]
+    if IS_WINDOWS:
+        return is_windows_64bit()
+    return "64" in platform.architecture()[0]
+
 
 HTTP_TIMEOUT = 15.0
 
-REPO_URL = 'https://github.com/official-stockfish/books'
-EXE_SUFFIX = '.exe' if IS_WINDOWS else ''
-MAKE_CMD = 'make COMP=mingw ' if IS_WINDOWS else 'make COMP=gcc '
+REPO_URL = "https://github.com/official-stockfish/books"
+EXE_SUFFIX = ".exe" if IS_WINDOWS else ""
+MAKE_CMD = "make COMP=mingw " if IS_WINDOWS else "make COMP=gcc "
+
 
 def send_api_post_request(api_url, payload):
-  return requests.post(api_url, data=json.dumps(payload), headers={
-    'Content-Type': 'application/json'
-  }, timeout=HTTP_TIMEOUT)
+    return requests.post(
+        api_url,
+        data=json.dumps(payload),
+        headers={"Content-Type": "application/json"},
+        timeout=HTTP_TIMEOUT,
+    )
+
 
 def github_api(repo):
-  """ Convert from https://github.com/<user>/<repo>
+    """ Convert from https://github.com/<user>/<repo>
      To https://api.github.com/repos/<user>/<repo> """
-  return repo.replace('https://github.com', 'https://api.github.com/repos')
+    return repo.replace("https://github.com", "https://api.github.com/repos")
+
 
 def required_net(engine):
-  net = None
-  print('Obtaining EvalFile of %s ...' % (os.path.basename(engine)))
-  p = subprocess.Popen([engine, 'uci'], stdout=subprocess.PIPE, universal_newlines=True, bufsize=1, close_fds=not IS_WINDOWS)
+    net = None
+    print("Obtaining EvalFile of %s ..." % (os.path.basename(engine)))
+    p = subprocess.Popen(
+        [engine, "uci"],
+        stdout=subprocess.PIPE,
+        universal_newlines=True,
+        bufsize=1,
+        close_fds=not IS_WINDOWS,
+    )
+
+    for line in iter(p.stdout.readline, ""):
+        if "EvalFile" in line:
+            net = line.split(" ")[6].strip()
 
-  for line in iter(p.stdout.readline,''):
-    if 'EvalFile' in line:
-      net = line.split(' ')[6].strip()
+    p.wait()
+    p.stdout.close()
 
-  p.wait()
-  p.stdout.close()
+    if p.returncode != 0:
+        raise Exception("uci exited with non-zero code %d" % (p.returncode))
 
-  if p.returncode != 0:
-    raise Exception('uci exited with non-zero code %d' % (p.returncode))
+    return net
 
-  return net
 
 def required_net_from_source():
-  """ Parse ucioption.cpp to find default net"""
-  net = None
-  with open('ucioption.cpp','r') as srcfile:
-    for line in srcfile.readlines():
-      if 'EvalFile' in line and 'Option' in line:
-        p = re.compile('nn-[a-z0-9]{12}.nnue')
-        m = p.search(line)
-        if m:
-          net = m.group(0)
+    """ Parse ucioption.cpp to find default net"""
+    net = None
+    with open("ucioption.cpp", "r") as srcfile:
+        for line in srcfile.readlines():
+            if "EvalFile" in line and "Option" in line:
+                p = re.compile("nn-[a-z0-9]{12}.nnue")
+                m = p.search(line)
+                if m:
+                    net = m.group(0)
 
-  return net
+    return net
 
 
 def download_net(remote, testing_dir, net):
-  url = remote + '/api/nn/' + net
-  r = requests.get(url, allow_redirects=True)
-  open(os.path.join(testing_dir, net), 'wb').write(r.content)
+    url = remote + "/api/nn/" + net
+    r = requests.get(url, allow_redirects=True)
+    open(os.path.join(testing_dir, net), "wb").write(r.content)
+
 
 def validate_net(testing_dir, net):
-  content = open(os.path.join(testing_dir, net), "rb").read()
-  hash = hashlib.sha256(content).hexdigest()
-  return hash[:12] == net[3:15]
+    content = open(os.path.join(testing_dir, net), "rb").read()
+
hash = hashlib.sha256(content).hexdigest() + return hash[:12] == net[3:15] + def verify_signature(engine, signature, remote, payload, concurrency): - global ARCH - if concurrency > 1: - with open(os.devnull, 'wb') as dev_null: - busy_process = subprocess.Popen([engine], stdin=subprocess.PIPE, stdout=dev_null, universal_newlines=True, bufsize=1, close_fds=not IS_WINDOWS) - busy_process.stdin.write('setoption name Threads value {:.0f}\n'.format(concurrency-1)) - busy_process.stdin.write('go infinite\n') - busy_process.stdin.flush() - - try: - bench_sig = '' - print('Verifying signature of %s ...' % (os.path.basename(engine))) - with open(os.devnull, 'wb') as dev_null: - p = subprocess.Popen([engine, 'bench'], stderr=subprocess.PIPE, stdout=dev_null, universal_newlines=True, bufsize=1, close_fds=not IS_WINDOWS) - p2 = subprocess.Popen([engine, 'compiler'], stdout=subprocess.PIPE, stderr=dev_null, universal_newlines=True, bufsize=1, close_fds=not IS_WINDOWS) - - for line in iter(p.stderr.readline, ''): - if 'Nodes searched' in line: - bench_sig = line.split(': ')[1].strip() - if 'Nodes/second' in line: - bench_nps = float(line.split(': ')[1].strip()) - p.wait() - p.stderr.close() - - for line in iter(p2.stdout.readline, ''): - if 'settings' in line: - ARCH = line.split(': ')[1].strip() - p2.wait() - p2.stdout.close() - - if p.returncode: - raise Exception('Bench exited with non-zero code %d' % (p.returncode)) - if p2.returncode: - raise Exception('Compiler info exited with non-zero code %d' % (p2.returncode)) - - if int(bench_sig) != int(signature): - message = 'Wrong bench in %s Expected: %s Got: %s' % (os.path.basename(engine), signature, bench_sig) - payload['message'] = message - send_api_post_request(remote + '/api/stop_run', payload) - raise Exception(message) - - finally: + global ARCH if concurrency > 1: - busy_process.communicate('quit\n') - busy_process.stdin.close() + with open(os.devnull, "wb") as dev_null: + busy_process = subprocess.Popen( + [engine], + stdin=subprocess.PIPE, + stdout=dev_null, + universal_newlines=True, + bufsize=1, + close_fds=not IS_WINDOWS, + ) + busy_process.stdin.write( + "setoption name Threads value {:.0f}\n".format(concurrency - 1) + ) + busy_process.stdin.write("go infinite\n") + busy_process.stdin.flush() + + try: + bench_sig = "" + print("Verifying signature of %s ..." 
% (os.path.basename(engine))) + with open(os.devnull, "wb") as dev_null: + p = subprocess.Popen( + [engine, "bench"], + stderr=subprocess.PIPE, + stdout=dev_null, + universal_newlines=True, + bufsize=1, + close_fds=not IS_WINDOWS, + ) + p2 = subprocess.Popen( + [engine, "compiler"], + stdout=subprocess.PIPE, + stderr=dev_null, + universal_newlines=True, + bufsize=1, + close_fds=not IS_WINDOWS, + ) + + for line in iter(p.stderr.readline, ""): + if "Nodes searched" in line: + bench_sig = line.split(": ")[1].strip() + if "Nodes/second" in line: + bench_nps = float(line.split(": ")[1].strip()) + p.wait() + p.stderr.close() + + for line in iter(p2.stdout.readline, ""): + if "settings" in line: + ARCH = line.split(": ")[1].strip() + p2.wait() + p2.stdout.close() + + if p.returncode: + raise Exception("Bench exited with non-zero code %d" % (p.returncode)) + if p2.returncode: + raise Exception( + "Compiler info exited with non-zero code %d" % (p2.returncode) + ) + + if int(bench_sig) != int(signature): + message = "Wrong bench in %s Expected: %s Got: %s" % ( + os.path.basename(engine), + signature, + bench_sig, + ) + payload["message"] = message + send_api_post_request(remote + "/api/stop_run", payload) + raise Exception(message) + + finally: + if concurrency > 1: + busy_process.communicate("quit\n") + busy_process.stdin.close() + + return bench_nps - return bench_nps def setup(item, testing_dir): - """Download item from FishCooking to testing_dir""" - tree = requests.get(github_api(REPO_URL) + '/git/trees/master', timeout=HTTP_TIMEOUT).json() - for blob in tree['tree']: - if blob['path'] == item: - print('Downloading %s ...' % (item)) - blob_json = requests.get(blob['url'], timeout=HTTP_TIMEOUT).json() - with open(os.path.join(testing_dir, item), 'wb+') as f: - f.write(b64decode(blob_json['content'])) - break - else: - raise Exception('Item %s not found' % (item)) + """Download item from FishCooking to testing_dir""" + tree = requests.get( + github_api(REPO_URL) + "/git/trees/master", timeout=HTTP_TIMEOUT + ).json() + for blob in tree["tree"]: + if blob["path"] == item: + print("Downloading %s ..." 
% (item)) + blob_json = requests.get(blob["url"], timeout=HTTP_TIMEOUT).json() + with open(os.path.join(testing_dir, item), "wb+") as f: + f.write(b64decode(blob_json["content"])) + break + else: + raise Exception("Item %s not found" % (item)) + def gcc_props(): - """Parse the output of g++ -Q -march=native --help=target and extract the available properties""" - p = subprocess.Popen(['g++', '-Q', '-march=native', '--help=target'], stdout=subprocess.PIPE, universal_newlines=True, bufsize=1, close_fds=not IS_WINDOWS) + """Parse the output of g++ -Q -march=native --help=target and extract the available properties""" + p = subprocess.Popen( + ["g++", "-Q", "-march=native", "--help=target"], + stdout=subprocess.PIPE, + universal_newlines=True, + bufsize=1, + close_fds=not IS_WINDOWS, + ) + + flags = [] + arch = "None" + for line in iter(p.stdout.readline, ""): + if "[enabled]" in line: + flags.append(line.split()[0]) + if "-march" in line and len(line.split()) == 2: + arch = line.split()[1] + + p.wait() + p.stdout.close() + if p.returncode != 0: + raise Exception("g++ target query failed with return code %d" % (p.returncode)) - flags=[] - arch="None" - for line in iter(p.stdout.readline,''): - if '[enabled]' in line: - flags.append(line.split()[0]) - if '-march' in line and len(line.split()) == 2: - arch = line.split()[1] + return {"flags": flags, "arch": arch} - p.wait() - p.stdout.close() - if p.returncode != 0: - raise Exception('g++ target query failed with return code %d' % (p.returncode)) +def make_targets(): + """Parse the output of make help and extract the available targets""" + p = subprocess.Popen( + ["make", "help"], + stdout=subprocess.PIPE, + universal_newlines=True, + bufsize=1, + close_fds=not IS_WINDOWS, + ) + + targets = [] + read_targets = False + + for line in iter(p.stdout.readline, ""): + if "Supported compilers:" in line: + read_targets = False + if read_targets and len(line.split()) > 1: + targets.append(line.split()[0]) + if "Supported archs:" in line: + read_targets = True - return {'flags' : flags, 'arch' : arch} + p.wait() + p.stdout.close() + if p.returncode != 0: + raise Exception("make help failed with return code %d" % (p.returncode)) -def make_targets(): - """Parse the output of make help and extract the available targets""" - p = subprocess.Popen(['make', 'help'], stdout=subprocess.PIPE, universal_newlines=True, bufsize=1, close_fds=not IS_WINDOWS) + return targets - targets=[] - read_targets = False - for line in iter(p.stdout.readline,''): - if 'Supported compilers:' in line: - read_targets = False - if read_targets and len(line.split())>1: - targets.append(line.split()[0]) - if 'Supported archs:' in line: - read_targets = True +def find_arch_string(): + """Find the best ARCH=... 
string based on the cpu/g++ capabilities and Makefile targets""" + + targets = make_targets() + + props = gcc_props() + + if is_64bit(): + if "-mavx512bw" in props["flags"] and "x86-64-avx512" in targets: + res = "x86-64-avx512" + elif ( + "-mbmi2" in props["flags"] + and "x86-64-bmi2" in targets + and not props["arch"] in ["znver1", "znver2"] + ): + res = "x86-64-bmi2" + elif "-mavx2" in props["flags"] and "x86-64-avx2" in targets: + res = "x86-64-avx2" + elif ( + "-mpopcnt" in props["flags"] + and "-msse4.1" in props["flags"] + and "x86-64-modern" in targets + ): + res = "x86-64-modern" + elif "-mssse3" in props["flags"] and "x86-64-ssse3" in targets: + res = "x86-64-ssse3" + elif ( + "-mpopcnt" in props["flags"] + and "-msse3" in props["flags"] + and "x86-64-sse3-popcnt" in targets + ): + res = "x86-64-sse3-popcnt" + else: + res = "x86-64" + else: + res = "x86-32" - p.wait() - p.stdout.close() + print("Available Makefile architecture targets: ", targets) + print("Available g++/cpu properties : ", props) + print("Determined the best architecture to be ", res) - if p.returncode != 0: - raise Exception('make help failed with return code %d' % (p.returncode)) + return "ARCH=" + res - return targets -def find_arch_string(): - """Find the best ARCH=... string based on the cpu/g++ capabilities and Makefile targets""" - - targets = make_targets() - - props = gcc_props() - - if is_64bit(): - if '-mavx512bw' in props['flags'] and 'x86-64-avx512' in targets: - res='x86-64-avx512' - elif '-mbmi2' in props['flags'] and 'x86-64-bmi2' in targets \ - and not props['arch'] in ['znver1', 'znver2']: - res='x86-64-bmi2' - elif '-mavx2' in props['flags'] and 'x86-64-avx2' in targets: - res='x86-64-avx2' - elif '-mpopcnt' in props['flags'] and '-msse4.1' in props['flags'] and 'x86-64-modern' in targets: - res='x86-64-modern' - elif '-mssse3' in props['flags'] and 'x86-64-ssse3' in targets: - res='x86-64-ssse3' - elif '-mpopcnt' in props['flags'] and '-msse3' in props['flags'] and 'x86-64-sse3-popcnt' in targets: - res='x86-64-sse3-popcnt' - else: - res='x86-64' - else: - res='x86-32' - - print("Available Makefile architecture targets: ", targets) - print("Available g++/cpu properties : ", props) - print("Determined the best architecture to be ", res) - - return 'ARCH=' + res - -def setup_engine(destination, worker_dir, testing_dir, remote, sha, repo_url, concurrency): - if os.path.exists(destination): os.remove(destination) - """Download and build sources in a temporary directory then move exe to destination""" - tmp_dir = tempfile.mkdtemp(dir=worker_dir) - - try: - os.chdir(tmp_dir) - with open('sf.gz', 'wb+') as f: - f.write(requests.get(github_api(repo_url) + '/zipball/' + sha, timeout=HTTP_TIMEOUT).content) - zip_file = ZipFile('sf.gz') - zip_file.extractall() - zip_file.close() - - for name in zip_file.namelist(): - if name.endswith('/src/'): - src_dir = name - os.chdir(src_dir) - - net = required_net_from_source() - if net: - print("Build uses default net: ", net) - if not os.path.exists(os.path.join(testing_dir, net)) or not validate_net(testing_dir, net): - download_net(remote, testing_dir, net) - if not validate_net(testing_dir, net): - raise Exception('Failed to validate the network: %s ' % (net)) - shutil.copyfile(os.path.join(testing_dir, net), net) - - ARCH = find_arch_string() - - subprocess.check_call(MAKE_CMD + ARCH + ' -j %s' % (concurrency) + ' profile-build', shell=True) - try: # try/pass needed for backwards compatibility with older stockfish, where 'make strip' fails under mingw. 
- subprocess.check_call(MAKE_CMD + ARCH + ' -j %s' % (concurrency) + ' strip', shell=True) +def setup_engine( + destination, worker_dir, testing_dir, remote, sha, repo_url, concurrency +): + if os.path.exists(destination): + os.remove(destination) + """Download and build sources in a temporary directory then move exe to destination""" + tmp_dir = tempfile.mkdtemp(dir=worker_dir) + + try: + os.chdir(tmp_dir) + with open("sf.gz", "wb+") as f: + f.write( + requests.get( + github_api(repo_url) + "/zipball/" + sha, timeout=HTTP_TIMEOUT + ).content + ) + zip_file = ZipFile("sf.gz") + zip_file.extractall() + zip_file.close() + + for name in zip_file.namelist(): + if name.endswith("/src/"): + src_dir = name + os.chdir(src_dir) + + net = required_net_from_source() + if net: + print("Build uses default net: ", net) + if not os.path.exists(os.path.join(testing_dir, net)) or not validate_net( + testing_dir, net + ): + download_net(remote, testing_dir, net) + if not validate_net(testing_dir, net): + raise Exception("Failed to validate the network: %s " % (net)) + shutil.copyfile(os.path.join(testing_dir, net), net) + + ARCH = find_arch_string() + + subprocess.check_call( + MAKE_CMD + ARCH + " -j %s" % (concurrency) + " profile-build", shell=True + ) + try: # try/pass needed for backwards compatibility with older stockfish, where 'make strip' fails under mingw. + subprocess.check_call( + MAKE_CMD + ARCH + " -j %s" % (concurrency) + " strip", shell=True + ) + except: + pass + + shutil.move("stockfish" + EXE_SUFFIX, destination) except: - pass + raise Exception("Failed to setup engine for %s" % (sha)) + finally: + os.chdir(worker_dir) + shutil.rmtree(tmp_dir) - shutil.move('stockfish'+ EXE_SUFFIX, destination) - except: - raise Exception('Failed to setup engine for %s' % (sha)) - finally: - os.chdir(worker_dir) - shutil.rmtree(tmp_dir) def kill_process(p): - try: - if IS_WINDOWS: - # Kill doesn't kill subprocesses on Windows - subprocess.call(['taskkill', '/F', '/T', '/PID', str(p.pid)]) - else: - p.kill() - except: - print('Note: ' + str(sys.exc_info()[0]) + ' killing the process pid: ' + str(p.pid) + ', possibly already terminated') - finally: - p.wait() - p.stdout.close() + try: + if IS_WINDOWS: + # Kill doesn't kill subprocesses on Windows + subprocess.call(["taskkill", "/F", "/T", "/PID", str(p.pid)]) + else: + p.kill() + except: + print( + "Note: " + + str(sys.exc_info()[0]) + + " killing the process pid: " + + str(p.pid) + + ", possibly already terminated" + ) + finally: + p.wait() + p.stdout.close() + def adjust_tc(tc, base_nps, concurrency): - factor = 1600000.0 / base_nps # 1.6Mnps is the reference core, also used in fishtest views. 
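
To make the scaling concrete (the reformatted body of adjust_tc continues below), a worked example under the stated 1.6 Mnps reference: a machine benching 800 knps gets factor 2.0, so a nominal 10+0.1 time control is stretched to 20.000+0.200 locally. A self-contained sketch of that arithmetic, with invented numbers:

    # Reproduces adjust_tc's scaling for a plain base+increment time control,
    # leaving out the moves/NN:SS parsing that the full function handles.
    base_nps = 800000
    factor = 1600000.0 / base_nps            # 2.0
    time_tc, increment = 10.0, 0.1           # nominal "10+0.1"
    scaled_tc = "%.3f+%.3f" % (time_tc * factor, increment * factor)
    tc_limit = time_tc * factor * 3 + increment * factor * 200
    print(scaled_tc, tc_limit)               # 20.000+0.200 100.0
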
- if base_nps < 700000: - sys.stderr.write('This machine is too slow to run fishtest effectively - sorry!\n') - sys.exit(1) - - # Parse the time control in cutechess format - chunks = tc.split('+') - increment = 0.0 - if len(chunks) == 2: - increment = float(chunks[1]) - - chunks = chunks[0].split('/') - num_moves = 0 - if len(chunks) == 2: - num_moves = int(chunks[0]) - - time_tc = chunks[-1] - chunks = time_tc.split(':') - if len(chunks) == 2: - time_tc = float(chunks[0]) * 60 + float(chunks[1]) - else: - time_tc = float(chunks[0]) - - # Rebuild scaled_tc now: cutechess-cli and stockfish parse 3 decimal places - scaled_tc = '%.3f' % (time_tc * factor) - tc_limit = time_tc * factor * 3 - if increment > 0.0: - scaled_tc += '+%.3f' % (increment * factor) - tc_limit += increment * factor * 200 - if num_moves > 0: - scaled_tc = '%d/%s' % (num_moves, scaled_tc) - tc_limit *= 100.0 / num_moves - - print('CPU factor : %f - tc adjusted to %s' % (factor, scaled_tc)) - return scaled_tc, tc_limit + factor = ( + 1600000.0 / base_nps + ) # 1.6Mnps is the reference core, also used in fishtest views. + if base_nps < 700000: + sys.stderr.write( + "This machine is too slow to run fishtest effectively - sorry!\n" + ) + sys.exit(1) + + # Parse the time control in cutechess format + chunks = tc.split("+") + increment = 0.0 + if len(chunks) == 2: + increment = float(chunks[1]) + + chunks = chunks[0].split("/") + num_moves = 0 + if len(chunks) == 2: + num_moves = int(chunks[0]) + + time_tc = chunks[-1] + chunks = time_tc.split(":") + if len(chunks) == 2: + time_tc = float(chunks[0]) * 60 + float(chunks[1]) + else: + time_tc = float(chunks[0]) + + # Rebuild scaled_tc now: cutechess-cli and stockfish parse 3 decimal places + scaled_tc = "%.3f" % (time_tc * factor) + tc_limit = time_tc * factor * 3 + if increment > 0.0: + scaled_tc += "+%.3f" % (increment * factor) + tc_limit += increment * factor * 200 + if num_moves > 0: + scaled_tc = "%d/%s" % (num_moves, scaled_tc) + tc_limit *= 100.0 / num_moves + + print("CPU factor : %f - tc adjusted to %s" % (factor, scaled_tc)) + return scaled_tc, tc_limit + def enqueue_output(out, queue): - for line in iter(out.readline, ''): - queue.put(line) + for line in iter(out.readline, ""): + queue.put(line) w_params = None b_params = None -def update_pentanomial(line,rounds): - def result_to_score(_result): - if _result=="1-0": - return 2 - elif _result=="0-1": - return 0 - elif _result=="1/2-1/2": - return 1 - else: - return -1 - if not 'pentanomial' in rounds.keys(): - rounds['pentanomial']=5*[0] - if not 'trinomial' in rounds.keys(): - rounds['trinomial']=3*[0] - - saved_sum_trinomial=sum(rounds['trinomial']) - current={} - - # Parse line like this: - # Finished game 4 (Base-5446e6f vs New-1a68b26): 1/2-1/2 {Draw by adjudication} - line=line.split() - if line[0]=='Finished' and line[1]=='game' and len(line)>=7: - round_=int(line[2]) - rounds[round_]=current - current['white']=line[3][1:] - current['black']=line[5][:-2] - i=current['result']=result_to_score(line[6]) - if round_%2==0: - if i!=-1: - rounds['trinomial'][2-i]+=1 #reversed colors - odd=round_-1 - even=round_ - else: - if i!=-1: - rounds['trinomial'][i]+=1 - odd=round_ - even=round_+1 - if odd in rounds.keys() and even in rounds.keys(): - assert(rounds[odd]['white'][0:3]=='New') - assert(rounds[odd]['white']==rounds[even]['black']) - assert(rounds[odd]['black']==rounds[even]['white']) - i=rounds[odd]['result'] - j=rounds[even]['result'] # even is reversed colors - if i!=-1 and j!=-1: - 
rounds['pentanomial'][i+2-j]+=1 - del rounds[odd] - del rounds[even] - rounds['trinomial'][i]-=1 - rounds['trinomial'][2-j]-=1 - assert(rounds['trinomial'][i]>=0) - assert(rounds['trinomial'][2-j]>=0) - - # make sure something happened, but not too much - assert(current.get('result',-1000)==-1 or abs(sum(rounds['trinomial'])-saved_sum_trinomial)==1) +def update_pentanomial(line, rounds): + def result_to_score(_result): + if _result == "1-0": + return 2 + elif _result == "0-1": + return 0 + elif _result == "1/2-1/2": + return 1 + else: + return -1 -def validate_pentanomial(wld, rounds): - def results_to_score(results): - return sum([results[i] * (i / 2.0) for i in range(len(results))]) - LDW = [wld[1], wld[2], wld[0]] - s3 = results_to_score(LDW) - s5 = results_to_score(rounds['pentanomial']) + results_to_score(rounds['trinomial']) - assert(sum(LDW) == 2 * sum(rounds['pentanomial']) + sum(rounds['trinomial'])) - epsilon = 1e-4 - assert(abs(s5 - s3) < epsilon) + if not "pentanomial" in rounds.keys(): + rounds["pentanomial"] = 5 * [0] + if not "trinomial" in rounds.keys(): + rounds["trinomial"] = 3 * [0] + saved_sum_trinomial = sum(rounds["trinomial"]) + current = {} -def parse_cutechess_output(p, remote, result, spsa, spsa_tuning, games_to_play, batch_size, tc_limit): + # Parse line like this: + # Finished game 4 (Base-5446e6f vs New-1a68b26): 1/2-1/2 {Draw by adjudication} + line = line.split() + if line[0] == "Finished" and line[1] == "game" and len(line) >= 7: + round_ = int(line[2]) + rounds[round_] = current + current["white"] = line[3][1:] + current["black"] = line[5][:-2] + i = current["result"] = result_to_score(line[6]) + if round_ % 2 == 0: + if i != -1: + rounds["trinomial"][2 - i] += 1 # reversed colors + odd = round_ - 1 + even = round_ + else: + if i != -1: + rounds["trinomial"][i] += 1 + odd = round_ + even = round_ + 1 + if odd in rounds.keys() and even in rounds.keys(): + assert rounds[odd]["white"][0:3] == "New" + assert rounds[odd]["white"] == rounds[even]["black"] + assert rounds[odd]["black"] == rounds[even]["white"] + i = rounds[odd]["result"] + j = rounds[even]["result"] # even is reversed colors + if i != -1 and j != -1: + rounds["pentanomial"][i + 2 - j] += 1 + del rounds[odd] + del rounds[even] + rounds["trinomial"][i] -= 1 + rounds["trinomial"][2 - j] -= 1 + assert rounds["trinomial"][i] >= 0 + assert rounds["trinomial"][2 - j] >= 0 + + # make sure something happened, but not too much + assert ( + current.get("result", -1000) == -1 + or abs(sum(rounds["trinomial"]) - saved_sum_trinomial) == 1 + ) - saved_stats=copy.deepcopy(result['stats']) - rounds = {} - q = Queue() - t = threading.Thread(target=enqueue_output, args=(p.stdout, q)) - t.daemon = True - t.start() +def validate_pentanomial(wld, rounds): + def results_to_score(results): + return sum([results[i] * (i / 2.0) for i in range(len(results))]) + + LDW = [wld[1], wld[2], wld[0]] + s3 = results_to_score(LDW) + s5 = results_to_score(rounds["pentanomial"]) + results_to_score(rounds["trinomial"]) + assert sum(LDW) == 2 * sum(rounds["pentanomial"]) + sum(rounds["trinomial"]) + epsilon = 1e-4 + assert abs(s5 - s3) < epsilon + + +def parse_cutechess_output( + p, remote, result, spsa, spsa_tuning, games_to_play, batch_size, tc_limit +): + + saved_stats = copy.deepcopy(result["stats"]) + rounds = {} + + q = Queue() + t = threading.Thread(target=enqueue_output, args=(p.stdout, q)) + t.daemon = True + t.start() + + end_time = datetime.datetime.now() + datetime.timedelta(seconds=tc_limit) + print("TC limit {} End time: 
{}".format(tc_limit, end_time)) + + num_games_updated = 0 + while datetime.datetime.now() < end_time: + try: + line = q.get_nowait() + except Empty: + if p.poll() is not None: + break + time.sleep(1) + continue + + sys.stdout.write(line) + sys.stdout.flush() + + # Have we reached the end of the match? Then just exit + if "Finished match" in line: + # The following assertion will fail if there are games without result. + # Does this ever happen? + assert num_games_updated == games_to_play + print("Finished match cleanly") + + # Parse line like this: + # Warning: New-eb6a21875e doesn't have option ThreatBySafePawn + if "Warning:" in line and "doesn't have option" in line: + message = r'Cutechess-cli says: "%s"' % line.strip() + result["message"] = message + send_api_post_request(remote + "/api/stop_run", result) + raise Exception(message) + + # Parse line like this: + # Finished game 1 (stockfish vs base): 0-1 {White disconnects} + if "disconnects" in line or "connection stalls" in line: + result["stats"]["crashes"] += 1 + + if "on time" in line: + result["stats"]["time_losses"] += 1 + + # Parse line like this: + # Score of stockfish vs base: 0 - 0 - 1 [0.500] 1 + if "Score" in line: + chunks = line.split(":") + chunks = chunks[1].split() + wld = [int(chunks[0]), int(chunks[2]), int(chunks[4])] + + validate_pentanomial( + wld, rounds + ) # check if cutechess-cli result is compatible with + # our own bookkeeping + + pentanomial = [ + rounds["pentanomial"][i] + saved_stats["pentanomial"][i] + for i in range(5) + ] + result["stats"]["pentanomial"] = pentanomial + + wld_pairs = {} # trinomial frequencies of completed game pairs + + # rounds['trinomial'] is ordered ldw + wld_pairs["wins"] = wld[0] - rounds["trinomial"][2] + wld_pairs["losses"] = wld[1] - rounds["trinomial"][0] + wld_pairs["draws"] = wld[2] - rounds["trinomial"][1] + + result["stats"]["wins"] = wld_pairs["wins"] + saved_stats["wins"] + result["stats"]["losses"] = wld_pairs["losses"] + saved_stats["losses"] + result["stats"]["draws"] = wld_pairs["draws"] + saved_stats["draws"] + + if spsa_tuning: + spsa["wins"] = wld_pairs["wins"] + spsa["losses"] = wld_pairs["losses"] + spsa["draws"] = wld_pairs["draws"] + + num_games_finished = ( + wld_pairs["wins"] + wld_pairs["losses"] + wld_pairs["draws"] + ) + + assert ( + 2 * sum(result["stats"]["pentanomial"]) + == result["stats"]["wins"] + + result["stats"]["losses"] + + result["stats"]["draws"] + ) + assert num_games_finished == 2 * sum(rounds["pentanomial"]) + assert num_games_finished <= num_games_updated + batch_size + assert num_games_finished <= games_to_play + + # Send an update_task request after a batch is full or if we have played all games + if (num_games_finished == num_games_updated + batch_size) or ( + num_games_finished == games_to_play + ): + # Attempt to send game results to the server. 
Retry a few times upon error + update_succeeded = False + for _ in range(5): + try: + t0 = datetime.datetime.utcnow() + response = send_api_post_request( + remote + "/api/update_task", result + ).json() + print( + " Task updated successfully in %ss" + % ((datetime.datetime.utcnow() - t0).total_seconds()) + ) + if not response["task_alive"]: + # This task is no longer necessary + print("Server told us task is no longer needed") + return response + update_succeeded = True + num_games_updated = num_games_finished + break + except Exception as e: + sys.stderr.write("Exception from calling update_task:\n") + print(e) + # traceback.print_exc(file=sys.stderr) + time.sleep(HTTP_TIMEOUT) + if not update_succeeded: + print("Too many failed update attempts") + break + + # act on line like this + # Finished game 4 (Base-5446e6f vs New-1a68b26): 1/2-1/2 {Draw by adjudication} + if "Finished game" in line: + update_pentanomial(line, rounds) + + now = datetime.datetime.now() + if now >= end_time: + print("{} is past end time {}".format(now, end_time)) + + return {"task_alive": True} + + +def launch_cutechess( + cmd, remote, result, spsa_tuning, games_to_play, batch_size, tc_limit +): + spsa = {"w_params": [], "b_params": [], "num_games": games_to_play} - end_time = datetime.datetime.now() + datetime.timedelta(seconds=tc_limit) - print('TC limit {} End time: {}'.format(tc_limit, end_time)) + if spsa_tuning: + # Request parameters for next game + t0 = datetime.datetime.utcnow() + req = send_api_post_request(remote + "/api/request_spsa", result).json() + print( + "Fetched SPSA parameters successfully in %ss" + % ((datetime.datetime.utcnow() - t0).total_seconds()) + ) + + global w_params, b_params + w_params = req["w_params"] + b_params = req["b_params"] + + result["spsa"] = spsa + else: + w_params = [] + b_params = [] - num_games_updated = 0 - while datetime.datetime.now() < end_time: + # Run cutechess-cli binary + idx = cmd.index("_spsa_") + cmd = ( + cmd[:idx] + + ["option.%s=%d" % (x["name"], round(x["value"])) for x in w_params] + + cmd[idx + 1 :] + ) + idx = cmd.index("_spsa_") + cmd = ( + cmd[:idx] + + ["option.%s=%d" % (x["name"], round(x["value"])) for x in b_params] + + cmd[idx + 1 :] + ) + + print(cmd) + p = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True, + bufsize=1, + close_fds=not IS_WINDOWS, + ) + + task_state = {"task_alive": False} try: - line = q.get_nowait() - except Empty: - if p.poll() is not None: - break - time.sleep(1) - continue - - sys.stdout.write(line) - sys.stdout.flush() - - # Have we reached the end of the match? Then just exit - if 'Finished match' in line: - # The following assertion will fail if there are games without result. - # Does this ever happen? 
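
The _spsa_ placeholder handling in launch_cutechess above replaces each of the two markers in-place with the per-side option assignments, rounded to integers. A sketch on invented data (engine names and parameter values are illustrative, not from the patch):

    # Splice option.Name=value entries where the "_spsa_" markers sit,
    # exactly as launch_cutechess does for the white and black engines.
    cmd = ["-engine", "cmd=new", "_spsa_", "-engine", "cmd=base", "_spsa_"]
    w_params = [{"name": "KingSafety", "value": 41.7}]
    b_params = [{"name": "KingSafety", "value": 38.2}]
    for params in (w_params, b_params):
        idx = cmd.index("_spsa_")
        cmd = (
            cmd[:idx]
            + ["option.%s=%d" % (x["name"], round(x["value"])) for x in params]
            + cmd[idx + 1 :]
        )
    print(cmd)
    # ['-engine', 'cmd=new', 'option.KingSafety=42',
    #  '-engine', 'cmd=base', 'option.KingSafety=38']
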
- assert(num_games_updated==games_to_play) - print('Finished match cleanly') - - # Parse line like this: - # Warning: New-eb6a21875e doesn't have option ThreatBySafePawn - if 'Warning:' in line and "doesn't have option" in line: - message = r'Cutechess-cli says: "%s"' % line.strip() - result['message']=message - send_api_post_request(remote + '/api/stop_run', result) - raise Exception(message) - - # Parse line like this: - # Finished game 1 (stockfish vs base): 0-1 {White disconnects} - if 'disconnects' in line or 'connection stalls' in line: - result['stats']['crashes'] += 1 + task_state = parse_cutechess_output( + p, remote, result, spsa, spsa_tuning, games_to_play, batch_size, tc_limit + ) + except Exception as e: + print("Exception running games") + traceback.print_exc(file=sys.stderr) - if 'on time' in line: - result['stats']['time_losses'] += 1 + kill_process(p) + return task_state - # Parse line like this: - # Score of stockfish vs base: 0 - 0 - 1 [0.500] 1 - if 'Score' in line: - chunks = line.split(':') - chunks = chunks[1].split() - wld = [int(chunks[0]), int(chunks[2]), int(chunks[4])] - - validate_pentanomial(wld, rounds) # check if cutechess-cli result is compatible with - # our own bookkeeping - - pentanomial=[rounds['pentanomial'][i]+saved_stats['pentanomial'][i] for i in range(5)] - result['stats']['pentanomial'] = pentanomial - - wld_pairs={} # trinomial frequencies of completed game pairs - - # rounds['trinomial'] is ordered ldw - wld_pairs['wins'] = wld[0] - rounds['trinomial'][2] - wld_pairs['losses'] = wld[1] - rounds['trinomial'][0] - wld_pairs['draws'] = wld[2] - rounds['trinomial'][1] - - result['stats']['wins'] = wld_pairs['wins'] + saved_stats['wins'] - result['stats']['losses'] = wld_pairs['losses'] + saved_stats['losses'] - result['stats']['draws'] = wld_pairs['draws'] + saved_stats['draws'] - - if spsa_tuning: - spsa['wins'] = wld_pairs['wins'] - spsa['losses'] = wld_pairs['losses'] - spsa['draws'] = wld_pairs['draws'] - - num_games_finished=wld_pairs['wins']+wld_pairs['losses']+wld_pairs['draws'] - - assert(2*sum(result['stats']['pentanomial'])==result['stats']['wins']+result['stats']['losses']+result['stats']['draws']) - assert(num_games_finished==2*sum(rounds['pentanomial'])) - assert(num_games_finished <= num_games_updated+batch_size) - assert(num_games_finished <= games_to_play) - - # Send an update_task request after a batch is full or if we have played all games - if (num_games_finished == num_games_updated+batch_size) or (num_games_finished==games_to_play): - # Attempt to send game results to the server. 
Retry a few times upon error - update_succeeded = False - for _ in range(5): - try: - t0 = datetime.datetime.utcnow() - response = send_api_post_request(remote + '/api/update_task', result).json() - print(" Task updated successfully in %ss" % ((datetime.datetime.utcnow() - t0).total_seconds())) - if not response['task_alive']: - # This task is no longer necessary - print('Server told us task is no longer needed') - return response - update_succeeded = True - num_games_updated = num_games_finished - break - except Exception as e: - sys.stderr.write('Exception from calling update_task:\n') - print(e) - # traceback.print_exc(file=sys.stderr) - time.sleep(HTTP_TIMEOUT) - if not update_succeeded: - print('Too many failed update attempts') - break - - # act on line like this - # Finished game 4 (Base-5446e6f vs New-1a68b26): 1/2-1/2 {Draw by adjudication} - if 'Finished game' in line: - update_pentanomial(line, rounds) - - now = datetime.datetime.now() - if now >= end_time: - print('{} is past end time {}'.format(now, end_time)) - - return { 'task_alive': True } - -def launch_cutechess(cmd, remote, result, spsa_tuning, games_to_play, batch_size, tc_limit): - spsa = { - 'w_params': [], - 'b_params': [], - 'num_games': games_to_play, - } - - if spsa_tuning: - # Request parameters for next game - t0 = datetime.datetime.utcnow() - req = send_api_post_request(remote + '/api/request_spsa', result).json() - print("Fetched SPSA parameters successfully in %ss" % ((datetime.datetime.utcnow() - t0).total_seconds())) - - global w_params, b_params - w_params = req['w_params'] - b_params = req['b_params'] - - result['spsa'] = spsa - else: - w_params = [] - b_params = [] - - # Run cutechess-cli binary - idx = cmd.index('_spsa_') - cmd = cmd[:idx] + ['option.%s=%d'%(x['name'], round(x['value'])) for x in w_params] + cmd[idx+1:] - idx = cmd.index('_spsa_') - cmd = cmd[:idx] + ['option.%s=%d'%(x['name'], round(x['value'])) for x in b_params] + cmd[idx+1:] - - print(cmd) - p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True, bufsize=1, close_fds=not IS_WINDOWS) - - task_state = { 'task_alive': False } - try: - task_state = parse_cutechess_output(p, remote, result, spsa, spsa_tuning, games_to_play, batch_size, tc_limit) - except Exception as e: - print('Exception running games') - traceback.print_exc(file=sys.stderr) - - kill_process(p) - return task_state def run_games(worker_info, password, remote, run, task_id): - task = run['my_task'] - - # Have we run any games on this task yet? 
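
A note on the pentanomial bookkeeping shown earlier in this file: with i the white score (0/1/2) of the odd game, where New has white, and j the white score of the even game, where colors are reversed, New scores i + (2 - j) points over the pair, and that sum is exactly the index update_pentanomial increments. A tiny sketch with invented results:

    # Pair histogram as kept by update_pentanomial; i and j are the 0/1/2
    # scores of the odd and even game, both from white's point of view.
    def pair_index(i, j):
        return i + 2 - j  # New's combined score over the reversed-color pair

    pentanomial = 5 * [0]
    pentanomial[pair_index(2, 1)] += 1  # New wins as white, draws as black -> 3
    pentanomial[pair_index(0, 2)] += 1  # New loses both games of the pair -> 0
    print(pentanomial)  # [1, 0, 0, 1, 0]
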
- - input_stats = task.get('stats', {'wins':0, 'losses':0, 'draws':0, 'crashes':0, 'time_losses':0, 'pentanomial':5*[0]}) - if not 'pentanomial' in input_stats: - input_stats['pentanomial']=5*[0] - - assert(2*sum(input_stats['pentanomial'])==input_stats['wins']+input_stats['losses']+input_stats['draws']) - - input_stats['crashes']=input_stats.get('crashes', 0) - input_stats['time_losses']=input_stats.get('time_losses', 0) - - result = { - 'username': worker_info['username'], - 'password': password, - 'run_id': str(run['_id']), - 'task_id': task_id, - 'stats': input_stats - } - - games_remaining = task['num_games'] - (input_stats['wins'] + input_stats['losses'] + input_stats['draws']) - - assert(games_remaining>0) - assert(games_remaining%2==0) - - book = run['args']['book'] - book_depth = run['args']['book_depth'] - new_options = run['args']['new_options'] - base_options = run['args']['base_options'] - threads = int(run['args']['threads']) - spsa_tuning = 'spsa' in run['args'] - repo_url = run['args'].get('tests_repo', REPO_URL) - games_concurrency = int(worker_info['concurrency']) // threads - - # Format options according to cutechess syntax - def parse_options(s): - results = [] - chunks = s.split('=') - if len(chunks) == 0: - return results - param = chunks[0] - for c in chunks[1:]: - val = c.split() - results.append('option.%s=%s' % (param, val[0])) - param = ' '.join(val[1:]) - return results - - new_options = parse_options(new_options) - base_options = parse_options(base_options) - - # Setup testing directory if not already exsisting - worker_dir = os.path.dirname(os.path.realpath(__file__)) - testing_dir = os.path.join(worker_dir, 'testing') - if not os.path.exists(testing_dir): - os.makedirs(testing_dir) - - # clean up old engines (keeping the 50 most recent) - engines = glob.glob(os.path.join(testing_dir, 'stockfish_*' + EXE_SUFFIX)) - if len(engines) > 50: - engines.sort(key=os.path.getmtime) - for old_engine in engines[:-50]: - try: - os.remove(old_engine) - except: - print('Note: failed to remove an old engine binary ' + str(old_engine)) - pass - - # create new engines - sha_new = run['args']['resolved_new'] - sha_base = run['args']['resolved_base'] - new_engine_name = 'stockfish_' + sha_new - base_engine_name = 'stockfish_' + sha_base - - new_engine = os.path.join(testing_dir, new_engine_name + EXE_SUFFIX) - base_engine = os.path.join(testing_dir, base_engine_name + EXE_SUFFIX) - cutechess = os.path.join(testing_dir, 'cutechess-cli' + EXE_SUFFIX) - - # Build from sources new and base engines as needed - if not os.path.exists(new_engine): - setup_engine(new_engine, worker_dir, testing_dir, remote, sha_new, repo_url, worker_info['concurrency']) - if not os.path.exists(base_engine): - setup_engine(base_engine, worker_dir, testing_dir, remote, sha_base, repo_url, worker_info['concurrency']) - - os.chdir(testing_dir) - - # Download book if not already existing - if not os.path.exists(os.path.join(testing_dir, book)) or os.stat(os.path.join(testing_dir, book)).st_size == 0: - zipball = book + '.zip' - setup(zipball, testing_dir) - zip_file = ZipFile(zipball) - zip_file.extractall() - zip_file.close() - os.remove(zipball) - - # Download cutechess if not already existing - if not os.path.exists(cutechess): - if len(EXE_SUFFIX) > 0: zipball = 'cutechess-cli-win.zip' - else: zipball = 'cutechess-cli-linux-%s.zip' % (platform.architecture()[0]) - setup(zipball, testing_dir) - zip_file = ZipFile(zipball) - zip_file.extractall() - zip_file.close() - os.remove(zipball) - os.chmod(cutechess, 
os.stat(cutechess).st_mode | stat.S_IEXEC) - - # clean up old networks (keeping the 20 most recent) - networks = glob.glob(os.path.join(testing_dir, 'nn-*.nnue')) - if len(networks) > 20: - networks.sort(key=os.path.getmtime) - for old_net in networks[:-20]: - try: - os.remove(old_net) - except: - print('Note: failed to remove an old network ' + str(old_net)) - pass - - # Add EvalFile with full path to cutechess options, and download networks if not already existing - net_base = required_net(base_engine) - if net_base: - base_options = base_options + ["option.EvalFile=%s"%(os.path.join(testing_dir, net_base))] - net_new = required_net(new_engine) - if net_new: - new_options = new_options + ["option.EvalFile=%s"%(os.path.join(testing_dir, net_new))] - - for net in [net_base, net_new]: - if net: - if not os.path.exists(os.path.join(testing_dir, net)) or not validate_net(testing_dir, net): - download_net(remote, testing_dir, net) - if not validate_net(testing_dir, net): - raise Exception('Failed to validate the network: %s ' % (net)) - - # pgn output setup - pgn_name = 'results-' + worker_info['unique_key'] + '.pgn' - if os.path.exists(pgn_name): - os.remove(pgn_name) - pgnfile = os.path.join(testing_dir, pgn_name) - - # Verify signatures are correct - verify_signature(new_engine, run['args']['new_signature'], remote, result, games_concurrency * threads) - base_nps = verify_signature(base_engine, run['args']['base_signature'], remote, result, games_concurrency * threads) - - # Benchmark to adjust cpu scaling - scaled_tc, tc_limit = adjust_tc(run['args']['tc'], base_nps, int(worker_info['concurrency'])) - result['nps'] = base_nps - result['ARCH'] = ARCH - - # Handle book or pgn file - pgn_cmd = [] - book_cmd = [] - if int(book_depth) <= 0: - pass - elif book.endswith('.pgn') or book.endswith('.epd'): - plies = 2 * int(book_depth) - pgn_cmd = ['-openings', 'file=%s' % (book), 'format=%s' % (book[-3:]), 'order=random', 'plies=%d' % (plies)] - else: - book_cmd = ['book=%s' % (book), 'bookdepth=%s' % (book_depth)] - - print('Running %s vs %s' % (run['args']['new_tag'], run['args']['base_tag'])) - - threads_cmd=[] - if not any("Threads" in s for s in new_options + base_options): - threads_cmd = ['option.Threads=%d' % (threads)] - - # If nodestime is being used, give engines extra grace time to - # make time losses virtually impossible - nodestime_cmd=[] - if any ("nodestime" in s for s in new_options + base_options): - nodestime_cmd = ['timemargin=10000'] - - def make_player(arg): - return run['args'][arg].split(' ')[0] - - if spsa_tuning: - tc_limit *= 2 - - while games_remaining > 0: - if spsa_tuning: - games_to_play = min(games_concurrency * 2, games_remaining) - pgnout = [] + task = run["my_task"] + + # Have we run any games on this task yet? 
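
The parse_options helper in run_games (original above, reformatted version below) copes with UCI option names that contain spaces: after splitting on '=', every middle chunk holds "value NextOptionName", so its first token closes the current option and the remaining tokens name the next one. Run standalone on an invented string:

    # Same algorithm as run_games' parse_options, shown in isolation.
    def parse_options(s):
        results = []
        chunks = s.split("=")
        param = chunks[0]
        for c in chunks[1:]:
            val = c.split()
            results.append("option.%s=%s" % (param, val[0]))
            param = " ".join(val[1:])
        return results

    print(parse_options("Hash=16 Use NNUE=true"))
    # ['option.Hash=16', 'option.Use NNUE=true']
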
+ + input_stats = task.get( + "stats", + { + "wins": 0, + "losses": 0, + "draws": 0, + "crashes": 0, + "time_losses": 0, + "pentanomial": 5 * [0], + }, + ) + if not "pentanomial" in input_stats: + input_stats["pentanomial"] = 5 * [0] + + assert ( + 2 * sum(input_stats["pentanomial"]) + == input_stats["wins"] + input_stats["losses"] + input_stats["draws"] + ) + + input_stats["crashes"] = input_stats.get("crashes", 0) + input_stats["time_losses"] = input_stats.get("time_losses", 0) + + result = { + "username": worker_info["username"], + "password": password, + "run_id": str(run["_id"]), + "task_id": task_id, + "stats": input_stats, + } + + games_remaining = task["num_games"] - ( + input_stats["wins"] + input_stats["losses"] + input_stats["draws"] + ) + + assert games_remaining > 0 + assert games_remaining % 2 == 0 + + book = run["args"]["book"] + book_depth = run["args"]["book_depth"] + new_options = run["args"]["new_options"] + base_options = run["args"]["base_options"] + threads = int(run["args"]["threads"]) + spsa_tuning = "spsa" in run["args"] + repo_url = run["args"].get("tests_repo", REPO_URL) + games_concurrency = int(worker_info["concurrency"]) // threads + + # Format options according to cutechess syntax + def parse_options(s): + results = [] + chunks = s.split("=") + if len(chunks) == 0: + return results + param = chunks[0] + for c in chunks[1:]: + val = c.split() + results.append("option.%s=%s" % (param, val[0])) + param = " ".join(val[1:]) + return results + + new_options = parse_options(new_options) + base_options = parse_options(base_options) + + # Setup testing directory if not already exsisting + worker_dir = os.path.dirname(os.path.realpath(__file__)) + testing_dir = os.path.join(worker_dir, "testing") + if not os.path.exists(testing_dir): + os.makedirs(testing_dir) + + # clean up old engines (keeping the 50 most recent) + engines = glob.glob(os.path.join(testing_dir, "stockfish_*" + EXE_SUFFIX)) + if len(engines) > 50: + engines.sort(key=os.path.getmtime) + for old_engine in engines[:-50]: + try: + os.remove(old_engine) + except: + print("Note: failed to remove an old engine binary " + str(old_engine)) + pass + + # create new engines + sha_new = run["args"]["resolved_new"] + sha_base = run["args"]["resolved_base"] + new_engine_name = "stockfish_" + sha_new + base_engine_name = "stockfish_" + sha_base + + new_engine = os.path.join(testing_dir, new_engine_name + EXE_SUFFIX) + base_engine = os.path.join(testing_dir, base_engine_name + EXE_SUFFIX) + cutechess = os.path.join(testing_dir, "cutechess-cli" + EXE_SUFFIX) + + # Build from sources new and base engines as needed + if not os.path.exists(new_engine): + setup_engine( + new_engine, + worker_dir, + testing_dir, + remote, + sha_new, + repo_url, + worker_info["concurrency"], + ) + if not os.path.exists(base_engine): + setup_engine( + base_engine, + worker_dir, + testing_dir, + remote, + sha_base, + repo_url, + worker_info["concurrency"], + ) + + os.chdir(testing_dir) + + # Download book if not already existing + if ( + not os.path.exists(os.path.join(testing_dir, book)) + or os.stat(os.path.join(testing_dir, book)).st_size == 0 + ): + zipball = book + ".zip" + setup(zipball, testing_dir) + zip_file = ZipFile(zipball) + zip_file.extractall() + zip_file.close() + os.remove(zipball) + + # Download cutechess if not already existing + if not os.path.exists(cutechess): + if len(EXE_SUFFIX) > 0: + zipball = "cutechess-cli-win.zip" + else: + zipball = "cutechess-cli-linux-%s.zip" % (platform.architecture()[0]) + setup(zipball, 
testing_dir)
+ zip_file = ZipFile(zipball)
+ zip_file.extractall()
+ zip_file.close()
+ os.remove(zipball)
+ os.chmod(cutechess, os.stat(cutechess).st_mode | stat.S_IEXEC)
+
+ # clean up old networks (keeping the 20 most recent)
+ networks = glob.glob(os.path.join(testing_dir, "nn-*.nnue"))
+ if len(networks) > 20:
+ networks.sort(key=os.path.getmtime)
+ for old_net in networks[:-20]:
+ try:
+ os.remove(old_net)
+ except:
+ print("Note: failed to remove an old network " + str(old_net))
+ pass
+
+ # Add EvalFile with full path to cutechess options, and download networks if not already existing
+ net_base = required_net(base_engine)
+ if net_base:
+ base_options = base_options + [
+ "option.EvalFile=%s" % (os.path.join(testing_dir, net_base))
+ ]
+ net_new = required_net(new_engine)
+ if net_new:
+ new_options = new_options + [
+ "option.EvalFile=%s" % (os.path.join(testing_dir, net_new))
+ ]
+
+ for net in [net_base, net_new]:
+ if net:
+ if not os.path.exists(os.path.join(testing_dir, net)) or not validate_net(
+ testing_dir, net
+ ):
+ download_net(remote, testing_dir, net)
+ if not validate_net(testing_dir, net):
+ raise Exception("Failed to validate the network: %s " % (net))
+
+ # pgn output setup
+ pgn_name = "results-" + worker_info["unique_key"] + ".pgn"
+ if os.path.exists(pgn_name):
+ os.remove(pgn_name)
+ pgnfile = os.path.join(testing_dir, pgn_name)
+
+ # Verify signatures are correct
+ verify_signature(
+ new_engine,
+ run["args"]["new_signature"],
+ remote,
+ result,
+ games_concurrency * threads,
+ )
+ base_nps = verify_signature(
+ base_engine,
+ run["args"]["base_signature"],
+ remote,
+ result,
+ games_concurrency * threads,
+ )
+
+ # Benchmark to adjust cpu scaling
+ scaled_tc, tc_limit = adjust_tc(
+ run["args"]["tc"], base_nps, int(worker_info["concurrency"])
+ )
+ result["nps"] = base_nps
+ result["ARCH"] = ARCH
+
+ # Handle book or pgn file
+ pgn_cmd = []
+ book_cmd = []
+ if int(book_depth) <= 0:
+ pass
+ elif book.endswith(".pgn") or book.endswith(".epd"):
+ plies = 2 * int(book_depth)
+ pgn_cmd = [
+ "-openings",
+ "file=%s" % (book),
+ "format=%s" % (book[-3:]),
+ "order=random",
+ "plies=%d" % (plies),
+ ]
 else:
- games_to_play = games_remaining
- pgnout = ['-pgnout', pgn_name]
+ book_cmd = ["book=%s" % (book), "bookdepth=%s" % (book_depth)]

- batch_size = games_concurrency * 2 # update frequency
+ print("Running %s vs %s" % (run["args"]["new_tag"], run["args"]["base_tag"]))

- if 'sprt' in run['args']:
- batch_size = 2 * run['args']['sprt'].get('batch_size',1)
- assert(games_to_play%batch_size==0)
+ threads_cmd = []
+ if not any("Threads" in s for s in new_options + base_options):
+ threads_cmd = ["option.Threads=%d" % (threads)]

- assert(batch_size%2==0)
- assert(games_to_play%2==0)
+ # If nodestime is being used, give engines extra grace time to
+ # make time losses virtually impossible
+ nodestime_cmd = []
+ if any("nodestime" in s for s in new_options + base_options):
+ nodestime_cmd = ["timemargin=10000"]

- # Run cutechess-cli binary
- cmd = [ cutechess, '-repeat', '-rounds', str(int(games_to_play/2)), '-games', ' 2', '-tournament', 'gauntlet'] + pgnout + \
- ['-site', 'https://tests.stockfishchess.org/tests/view/' + run['_id']] + \
- ['-event', 'Batch %d: %s vs %s' % (task_id, make_player('new_tag'), make_player('base_tag'))] + \
- ['-srand', "%d" % struct.unpack("<L", os.urandom(struct.calcsize("<L")))] + \
+ while games_remaining > 0:
+ if spsa_tuning:
+ games_to_play = min(games_concurrency * 2, games_remaining)
+ pgnout = []
+ else:
+ games_to_play = games_remaining
+ pgnout = ["-pgnout", pgn_name]
+
+ batch_size = games_concurrency * 2 # update frequency
+
+ if "sprt" in run["args"]:
+ batch_size = 2 * run["args"]["sprt"].get("batch_size", 1)
+ assert games_to_play % batch_size == 0
+
+ assert batch_size % 2 == 0
+ assert games_to_play % 2 == 0
+
+ # Run cutechess-cli binary
+ cmd = (
+ [
+ cutechess,
+ "-repeat",
+ "-rounds",
+ str(int(games_to_play / 2)),
+ "-games",
+ " 2",
+ "-tournament",
+ "gauntlet",
+ ]
+ + pgnout
+ + ["-site", "https://tests.stockfishchess.org/tests/view/" + run["_id"]]
+ + [
+ "-event",
+ "Batch %d: %s vs %s"
+ % (task_id, make_player("new_tag"), make_player("base_tag")),
+ ]
+ + ["-srand", "%d" % struct.unpack("<L", os.urandom(struct.calcsize("<L")))]
- if len(bkp_dirs) > num_bkps:
- bkp_dirs.sort(key=os.path.getmtime)
- for old_bkp_dir in bkp_dirs[:-num_bkps]:
+ worker_dir = os.path.dirname(os.path.realpath(__file__))
+ update_dir = os.path.join(worker_dir, "update")
+ if not os.path.exists(update_dir):
+ os.makedirs(update_dir)
+
+ worker_zip = os.path.join(update_dir, "wk.zip")
+ with open(worker_zip, "wb+") as f:
+ f.write(requests.get(WORKER_URL).content)
+
+ zip_file = ZipFile(worker_zip)
+ zip_file.extractall(update_dir)
+ zip_file.close()
+ prefix = os.path.commonprefix([n.filename for n in zip_file.infolist()])
+ worker_src = os.path.join(update_dir, prefix + "worker")
+ if not test:
+ copy_tree(worker_src, worker_dir)
+ else:
+ file_list = os.listdir(worker_src)
+ shutil.rmtree(update_dir)
+
+ # rename the testing_dir as backup
+ # and to trigger the download of update files
+ testing_dir = os.path.join(worker_dir, "testing")
+ if os.path.exists(testing_dir):
try:
- shutil.rmtree(old_bkp_dir)
- except:
- print('Note: failed to remove the old backup folder ' + str(old_bkp_dir))
- pass
+ time_stamp = str(datetime.datetime.timestamp(datetime.datetime.utcnow()))
+ except AttributeError: # python2
+ dt_utc = datetime.datetime.utcnow()
+ time_stamp = str(time.mktime(dt_utc.timetuple()) + dt_utc.microsecond / 1e6)
+
+ bkp_testing_dir = os.path.join(worker_dir, "_testing_" + time_stamp)
+ shutil.move(testing_dir, bkp_testing_dir)
+ os.makedirs(testing_dir)
+ # delete the old engine binaries
+ engines = glob.glob(os.path.join(bkp_testing_dir, "stockfish_*"))
+ for engine in engines:
+ try:
+ os.remove(engine)
+ except:
+ print("Note: failed to delete an engine binary " + str(engine))
+ pass
+ # clean up old folder backups (keeping the num_bkps most recent)
+ bkp_dirs = glob.glob(os.path.join(worker_dir, "_testing_*"))
+ num_bkps = 3
+ if len(bkp_dirs) > num_bkps:
+ bkp_dirs.sort(key=os.path.getmtime)
+ for old_bkp_dir in bkp_dirs[:-num_bkps]:
+ try:
+ shutil.rmtree(old_bkp_dir)
+ except:
+ print(
+ "Note: failed to remove the old backup folder "
+ + str(old_bkp_dir)
+ )
+ pass
+
+ print("start_dir: " + start_dir)
+ if restart:
+ do_restart()
- print("start_dir: " + start_dir)
- if restart:
- do_restart()
+ if test:
+ return file_list
- if test:
- return file_list
-if __name__ == '__main__':
- update(False)
+
+if __name__ == "__main__":
+ update(False)
diff --git a/worker/worker.py b/worker/worker.py
index 242ae0278..5a9501665 100644
--- a/worker/worker.py
+++ b/worker/worker.py
@@ -14,12 +14,15 @@ import time
 import traceback
 import uuid
+
 try:
- from ConfigParser import SafeConfigParser
- config = SafeConfigParser()
+ from ConfigParser import SafeConfigParser
+
+ config = SafeConfigParser()
 except ImportError:
- from configparser import ConfigParser # Python3
- config = ConfigParser()
+ from configparser import ConfigParser # Python3
+
+ config = ConfigParser()
 import zlib
 import base64
 from optparse import OptionParser
@@ -34,255
+37,309 @@ def setup_config_file(config_file): - ''' Config file setup, adds defaults if not existing ''' - config.read(config_file) - - mem = 0 - system_type = platform.system().lower() - try: - if 'linux' in system_type: - cmd = 'free -b' - elif 'windows' in system_type: - cmd = 'wmic computersystem get TotalPhysicalMemory' - elif 'darwin' in system_type: - cmd = 'sysctl hw.memsize' - else: - cmd = '' - print('Unknown system') - with os.popen(cmd) as proc: - mem_str = str(proc.readlines()) - mem = int(re.search(r'\d+', mem_str).group()) - print('Memory: ' + str(mem)) - except: - traceback.print_exc() - pass - - defaults = [('login', 'username', ''), ('login', 'password', ''), - ('parameters', 'protocol', 'https'), - ('parameters', 'host', 'tests.stockfishchess.org'), - ('parameters', 'port', '443'), - ('parameters', 'concurrency', '3'), - ('parameters', 'max_memory', str(int(mem / 2 / 1024 / 1024))), - ('parameters', 'min_threads', '1'), - ] - - for v in defaults: - if not config.has_section(v[0]): - config.add_section(v[0]) - if not config.has_option(v[0], v[1]): - config.set(*v) - with open(config_file, 'w') as f: - config.write(f) + """ Config file setup, adds defaults if not existing """ + config.read(config_file) + + mem = 0 + system_type = platform.system().lower() + try: + if "linux" in system_type: + cmd = "free -b" + elif "windows" in system_type: + cmd = "wmic computersystem get TotalPhysicalMemory" + elif "darwin" in system_type: + cmd = "sysctl hw.memsize" + else: + cmd = "" + print("Unknown system") + with os.popen(cmd) as proc: + mem_str = str(proc.readlines()) + mem = int(re.search(r"\d+", mem_str).group()) + print("Memory: " + str(mem)) + except: + traceback.print_exc() + pass + + defaults = [ + ("login", "username", ""), + ("login", "password", ""), + ("parameters", "protocol", "https"), + ("parameters", "host", "tests.stockfishchess.org"), + ("parameters", "port", "443"), + ("parameters", "concurrency", "3"), + ("parameters", "max_memory", str(int(mem / 2 / 1024 / 1024))), + ("parameters", "min_threads", "1"), + ] + + for v in defaults: + if not config.has_section(v[0]): + config.add_section(v[0]) + if not config.has_option(v[0], v[1]): + config.set(*v) + with open(config_file, "w") as f: + config.write(f) + + return config - return config def on_sigint(signal, frame): - global ALIVE - ALIVE = False - raise Exception('Terminated by signal') + global ALIVE + ALIVE = False + raise Exception("Terminated by signal") + rate = None + def get_rate(): - global rate - try: - rate = requests.get('https://api.github.com/rate_limit', timeout=HTTP_TIMEOUT).json()['rate'] - except Exception as e: - sys.stderr.write('Exception fetching rate_limit:\n') - print(e) - rate = { 'remaining': 0, 'limit': 5000 } - return True - remaining = rate['remaining'] - print("API call rate limits:", rate) - return remaining >= math.sqrt(rate['limit']) + global rate + try: + rate = requests.get( + "https://api.github.com/rate_limit", timeout=HTTP_TIMEOUT + ).json()["rate"] + except Exception as e: + sys.stderr.write("Exception fetching rate_limit:\n") + print(e) + rate = {"remaining": 0, "limit": 5000} + return True + remaining = rate["remaining"] + print("API call rate limits:", rate) + return remaining >= math.sqrt(rate["limit"]) + def worker(worker_info, password, remote): - global ALIVE - - payload = { - 'worker_info': worker_info, - 'password': password, - } - - try: - print('Fetch task...') - if not get_rate(): - raise Exception('Near API limit') - - t0 = datetime.utcnow() - req = 
requests.post(remote + '/api/request_version', data=json.dumps(payload), headers={'Content-type': 'application/json'}, timeout=HTTP_TIMEOUT) - req = json.loads(req.text) - - if 'version' not in req: - print('Incorrect username/password') - time.sleep(5) - sys.exit(1) - - if req['version'] > WORKER_VERSION: - print('Updating worker version to %s' % (req['version'])) - update() - print("Worker version checked successfully in %ss" % ((datetime.utcnow() - t0).total_seconds())) - - t0 = datetime.utcnow() - worker_info['rate'] = rate - req = requests.post(remote + '/api/request_task', - data=json.dumps(payload), - headers={'Content-type': 'application/json'}, - timeout=HTTP_TIMEOUT) - req = json.loads(req.text) - except Exception as e: - sys.stderr.write('Exception accessing host:\n') - print(e) -# traceback.print_exc() - time.sleep(random.randint(10, 60)) - return - - print("Task requested in %ss" % ((datetime.utcnow() - t0).total_seconds())) - if 'error' in req: - raise Exception('Error from remote: %s' % (req['error'])) - - # No tasks ready for us yet, just wait... - if 'task_waiting' in req: - print('No tasks available at this time, waiting...\n') - # Note that after this sleep we have another ALIVE HTTP_TIMEOUT... - time.sleep(random.randint(1, 10)) - return - - success = True - run, task_id = req['run'], req['task_id'] - try: - pgn_file = run_games(worker_info, password, remote, run, task_id) - except: - sys.stderr.write('\nException running games:\n') - traceback.print_exc() - success = False - finally: - payload = { - 'username': worker_info['username'], - 'password': password, - 'run_id': str(run['_id']), - 'task_id': task_id - } + global ALIVE + + payload = {"worker_info": worker_info, "password": password} + try: - requests.post(remote + '/api/failed_task', data=json.dumps(payload), - headers={'Content-type': 'application/json'}, - timeout=HTTP_TIMEOUT) + print("Fetch task...") + if not get_rate(): + raise Exception("Near API limit") + + t0 = datetime.utcnow() + req = requests.post( + remote + "/api/request_version", + data=json.dumps(payload), + headers={"Content-type": "application/json"}, + timeout=HTTP_TIMEOUT, + ) + req = json.loads(req.text) + + if "version" not in req: + print("Incorrect username/password") + time.sleep(5) + sys.exit(1) + + if req["version"] > WORKER_VERSION: + print("Updating worker version to %s" % (req["version"])) + update() + print( + "Worker version checked successfully in %ss" + % ((datetime.utcnow() - t0).total_seconds()) + ) + + t0 = datetime.utcnow() + worker_info["rate"] = rate + req = requests.post( + remote + "/api/request_task", + data=json.dumps(payload), + headers={"Content-type": "application/json"}, + timeout=HTTP_TIMEOUT, + ) + req = json.loads(req.text) + except Exception as e: + sys.stderr.write("Exception accessing host:\n") + print(e) + # traceback.print_exc() + time.sleep(random.randint(10, 60)) + return + + print("Task requested in %ss" % ((datetime.utcnow() - t0).total_seconds())) + if "error" in req: + raise Exception("Error from remote: %s" % (req["error"])) + + # No tasks ready for us yet, just wait... + if "task_waiting" in req: + print("No tasks available at this time, waiting...\n") + # Note that after this sleep we have another ALIVE HTTP_TIMEOUT... 
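
As an aside on get_rate() above: the worker keeps requesting work only while the remaining GitHub API quota stays above sqrt(limit); with the default limit of 5000 that floor is about 71 calls. A quick illustration, with invented numbers:

    import math

    # Same predicate get_rate() returns: enough quota left to keep building.
    def quota_ok(remaining, limit=5000):
        return remaining >= math.sqrt(limit)

    print(quota_ok(500))  # True: well above the floor
    print(quota_ok(70))   # False: below sqrt(5000) ~ 70.7, so back off
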
+ time.sleep(random.randint(1, 10)) + return + + success = True + run, task_id = req["run"], req["task_id"] + try: + pgn_file = run_games(worker_info, password, remote, run, task_id) except: - pass - - if success and ALIVE: - sleep = random.randint(1, 10) - print('Wait %d seconds before upload of PGN...' % (sleep)) - time.sleep(sleep) - if not 'spsa' in run['args']: + sys.stderr.write("\nException running games:\n") + traceback.print_exc() + success = False + finally: + payload = { + "username": worker_info["username"], + "password": password, + "run_id": str(run["_id"]), + "task_id": task_id, + } + try: + requests.post( + remote + "/api/failed_task", + data=json.dumps(payload), + headers={"Content-type": "application/json"}, + timeout=HTTP_TIMEOUT, + ) + except: + pass + + if success and ALIVE: + sleep = random.randint(1, 10) + print("Wait %d seconds before upload of PGN..." % (sleep)) + time.sleep(sleep) + if not "spsa" in run["args"]: + try: + with open(pgn_file, "r") as f: + data = f.read() + # Ignore non utf-8 characters in PGN file + if sys.version_info[0] == 2: + data = data.decode("utf-8", "ignore").encode( + "utf-8" + ) # Python 2 + else: + data = bytes(data, "utf-8").decode( + "utf-8", "ignore" + ) # Python 3 + payload["pgn"] = base64.b64encode( + zlib.compress(data.encode("utf-8")) + ).decode() + print( + "Uploading compressed PGN of %d bytes" % (len(payload["pgn"])) + ) + requests.post( + remote + "/api/upload_pgn", + data=json.dumps(payload), + headers={"Content-type": "application/json"}, + timeout=HTTP_TIMEOUT, + ) + except Exception as e: + sys.stderr.write("\nException PGN upload:\n") + print(e) + # traceback.print_exc() try: - with open(pgn_file, 'r') as f: - data = f.read() - # Ignore non utf-8 characters in PGN file - if sys.version_info[0] == 2: - data = data.decode('utf-8', 'ignore').encode('utf-8') # Python 2 - else: - data = bytes(data, 'utf-8').decode('utf-8', 'ignore') # Python 3 - payload['pgn'] = base64.b64encode(zlib.compress( - data.encode('utf-8'))).decode() - print('Uploading compressed PGN of %d bytes' % (len(payload['pgn']))) - requests.post(remote + '/api/upload_pgn', data=json.dumps(payload), - headers={'Content-type': 'application/json'}, - timeout=HTTP_TIMEOUT) - except Exception as e: - sys.stderr.write('\nException PGN upload:\n') - print(e) -# traceback.print_exc() + os.remove(pgn_file) + except: + pass + sys.stderr.write("Task exited\n") + + return success + + +def main(): + worker_dir = path.dirname(path.realpath(__file__)) + print("Worker started in " + worker_dir + " ...\n") + + signal.signal(signal.SIGINT, on_sigint) + signal.signal(signal.SIGTERM, on_sigint) + + config_file = path.join(worker_dir, "fishtest.cfg") + config = setup_config_file(config_file) + parser = OptionParser() + parser.add_option( + "-P", + "--protocol", + dest="protocol", + default=config.get("parameters", "protocol"), + ) + parser.add_option( + "-n", "--host", dest="host", default=config.get("parameters", "host") + ) + parser.add_option( + "-p", "--port", dest="port", default=config.get("parameters", "port") + ) + parser.add_option( + "-c", + "--concurrency", + dest="concurrency", + default=config.get("parameters", "concurrency"), + ) + parser.add_option( + "-m", + "--max_memory", + dest="max_memory", + default=config.get("parameters", "max_memory"), + ) + parser.add_option( + "-t", + "--min_threads", + dest="min_threads", + default=config.get("parameters", "min_threads"), + ) + (options, args) = parser.parse_args() + + if len(args) != 2: + # Try to read parameters from 
the the config file + username = config.get("login", "username") + password = config.get("login", "password", raw=True) + if len(username) != 0 and len(password) != 0: + args.extend([username, password]) + else: + sys.stderr.write("%s [username] [password]\n" % (sys.argv[0])) + sys.exit(1) + + # Write command line parameters to the config file + config.set("login", "username", args[0]) + config.set("login", "password", args[1]) + config.set("parameters", "protocol", options.protocol) + config.set("parameters", "host", options.host) + config.set("parameters", "port", options.port) + config.set("parameters", "concurrency", options.concurrency) + config.set("parameters", "max_memory", options.max_memory) + config.set("parameters", "min_threads", options.min_threads) + with open(config_file, "w") as f: + config.write(f) + + protocol = options.protocol.lower() + if protocol not in ["http", "https"]: + sys.stderr.write("Wrong protocol, use https or http\n") + sys.exit(1) + elif protocol == "http" and options.port == "443": + # Rewrite old port 443 to 80 + options.port = "80" + elif protocol == "https" and options.port == "80": + # Rewrite old port 80 to 443 + options.port = "443" + remote = "{}://{}:{}".format(protocol, options.host, options.port) + print("Worker version {} connecting to {}".format(WORKER_VERSION, remote)) + try: - os.remove(pgn_file) + cpu_count = min(int(options.concurrency), multiprocessing.cpu_count() - 1) except: - pass - sys.stderr.write('Task exited\n') + cpu_count = int(options.concurrency) + + if cpu_count <= 0: + sys.stderr.write("Not enough CPUs to run fishtest (it requires at least two)\n") + sys.exit(1) + + uname = platform.uname() + worker_info = { + "uname": uname[0] + " " + uname[2], + "architecture": platform.architecture(), + "concurrency": cpu_count, + "max_memory": int(options.max_memory), + "min_threads": int(options.min_threads), + "username": args[0], + "version": "%s:%s" % (WORKER_VERSION, sys.version_info[0]), + "unique_key": str(uuid.uuid4()), + } - return success + success = True + global ALIVE + while ALIVE: + if path.isfile(path.join(worker_dir, "fish.exit")): + break + if not success: + time.sleep(HTTP_TIMEOUT) + success = worker(worker_info, args[1], remote) -def main(): - worker_dir = path.dirname(path.realpath(__file__)) - print("Worker started in " + worker_dir + " ...\n") - - signal.signal(signal.SIGINT, on_sigint) - signal.signal(signal.SIGTERM, on_sigint) - - config_file = path.join(worker_dir, 'fishtest.cfg') - config = setup_config_file(config_file) - parser = OptionParser() - parser.add_option('-P', '--protocol', dest='protocol', default=config.get('parameters', 'protocol')) - parser.add_option('-n', '--host', dest='host', default=config.get('parameters', 'host')) - parser.add_option('-p', '--port', dest='port', default=config.get('parameters', 'port')) - parser.add_option('-c', '--concurrency', dest='concurrency', default=config.get('parameters', 'concurrency')) - parser.add_option('-m', '--max_memory', dest='max_memory', default=config.get('parameters', 'max_memory')) - parser.add_option('-t', '--min_threads', dest='min_threads', default=config.get('parameters', 'min_threads')) - (options, args) = parser.parse_args() - - if len(args) != 2: - # Try to read parameters from the the config file - username = config.get('login', 'username') - password = config.get('login', 'password', raw=True) - if len(username) != 0 and len(password) != 0: - args.extend([ username, password ]) - else: - sys.stderr.write('%s [username] [password]\n' % 
(sys.argv[0])) - sys.exit(1) - - # Write command line parameters to the config file - config.set('login', 'username', args[0]) - config.set('login', 'password', args[1]) - config.set('parameters', 'protocol', options.protocol) - config.set('parameters', 'host', options.host) - config.set('parameters', 'port', options.port) - config.set('parameters', 'concurrency', options.concurrency) - config.set('parameters', 'max_memory', options.max_memory) - config.set('parameters', 'min_threads', options.min_threads) - with open(config_file, 'w') as f: - config.write(f) - - protocol = options.protocol.lower() - if protocol not in ['http', 'https']: - sys.stderr.write('Wrong protocol, use https or http\n') - sys.exit(1) - elif protocol == 'http' and options.port == '443': - # Rewrite old port 443 to 80 - options.port = '80' - elif protocol == 'https' and options.port == '80': - # Rewrite old port 80 to 443 - options.port = '443' - remote = '{}://{}:{}'.format(protocol, options.host, options.port) - print('Worker version {} connecting to {}'.format(WORKER_VERSION, remote)) - - try: - cpu_count = min(int(options.concurrency), multiprocessing.cpu_count() - 1) - except: - cpu_count = int(options.concurrency) - - if cpu_count <= 0: - sys.stderr.write('Not enough CPUs to run fishtest (it requires at least two)\n') - sys.exit(1) - - uname = platform.uname() - worker_info = { - 'uname': uname[0] + ' ' + uname[2], - 'architecture': platform.architecture(), - 'concurrency': cpu_count, - 'max_memory': int(options.max_memory), - 'min_threads': int(options.min_threads), - 'username': args[0], - 'version': "%s:%s" % (WORKER_VERSION, sys.version_info[0]), - 'unique_key': str(uuid.uuid4()), - } - - success = True - global ALIVE - while ALIVE: - if path.isfile(path.join(worker_dir, 'fish.exit')): - break - if not success: - time.sleep(HTTP_TIMEOUT) - success = worker(worker_info, args[1], remote) - -if __name__ == '__main__': - main() + +if __name__ == "__main__": + main()
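
One pattern from setup_config_file() earlier in worker.py's diff is worth calling out: fishtest.cfg is seeded from a list of (section, option, default) triples, and only missing entries are filled in, so values a user has edited survive worker updates. The same pattern in miniature; the file name and the two defaults here are illustrative, not the worker's actual configuration:

    try:
        from configparser import ConfigParser  # Python 3
    except ImportError:
        from ConfigParser import SafeConfigParser as ConfigParser  # Python 2

    config = ConfigParser()
    config.read("example.cfg")  # stand-in path, not the worker's fishtest.cfg

    # Add each (section, option, default) only where absent.
    defaults = [("login", "username", ""), ("parameters", "concurrency", "3")]
    for v in defaults:
        if not config.has_section(v[0]):
            config.add_section(v[0])
        if not config.has_option(v[0], v[1]):
            config.set(*v)

    with open("example.cfg", "w") as f:
        config.write(f)
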