From e0dda9ddb02d65eb4676f1d6e7195f47667f8403 Mon Sep 17 00:00:00 2001 From: ppigazzini Date: Tue, 27 Feb 2024 14:17:49 +0100 Subject: [PATCH] Add multiple nets support to the server Required by the new SF architecture with multiple nets, see: https://github.com/official-stockfish/Stockfish/pull/4915 https://github.com/official-stockfish/Stockfish/pull/5068 Also: - fix the increment of the download of the nets - improve the net code to deal with a develepment server built from scratch --- server/fishtest/api.py | 3 ++ server/fishtest/rundb.py | 23 +++++---- server/fishtest/views.py | 98 ++++++++++++++++++++++---------------- server/tests/test_api.py | 4 +- server/tests/test_rundb.py | 8 ++-- 5 files changed, 78 insertions(+), 58 deletions(-) diff --git a/server/fishtest/api.py b/server/fishtest/api.py index 352fc3d031..51028d1b03 100644 --- a/server/fishtest/api.py +++ b/server/fishtest/api.py @@ -608,6 +608,9 @@ def download_nn(self): nn = self.request.rundb.get_nn(self.request.matchdict["id"]) if nn is None: raise exception_response(404) + else: + self.request.rundb.increment_nn_downloads(self.request.matchdict["id"]) + return HTTPFound( "https://data.stockfishchess.org/nn/" + self.request.matchdict["id"] ) diff --git a/server/fishtest/rundb.py b/server/fishtest/rundb.py index dad9a9ee6c..01e40a5016 100644 --- a/server/fishtest/rundb.py +++ b/server/fishtest/rundb.py @@ -50,7 +50,7 @@ # To make this practical we will eventually put all schemas # in a separate module "schemas.py". -net_name = regex("nn-[a-z0-9]{12}.nnue", name="net_name") +net_name = regex("nn-[a-f0-9]{12}.nnue", name="net_name") tc = regex(r"([1-9]\d*/)?\d+(\.\d+)?(\+\d+(\.\d+)?)?", name="tc") str_int = regex(r"[1-9]\d*", name="str_int") sha = regex(r"[a-f0-9]{40}", name="sha") @@ -111,8 +111,8 @@ "args": { "base_tag": str, "new_tag": str, - "base_net": net_name, - "new_net": net_name, + "base_nets": [net_name, ...], + "new_nets": [net_name, ...], "num_games": int, "tc": tc, "new_tc": tc, @@ -322,8 +322,8 @@ def new_run( msg_new="", base_signature="", new_signature="", - base_net=None, - new_net=None, + base_nets=None, + new_nets=None, rescheduled_from=None, base_same_as_master=None, start_time=None, @@ -344,8 +344,8 @@ def new_run( run_args = { "base_tag": base_tag, "new_tag": new_tag, - "base_net": base_net, - "new_net": new_net, + "base_nets": base_nets, + "new_nets": new_nets, "num_games": num_games, "tc": tc, "new_tc": new_tc, @@ -477,11 +477,10 @@ def update_nn(self, net): self.nndb.update_one({"name": net["name"]}, {"$set": net}) def get_nn(self, name): - nn = self.nndb.find_one({"name": name}, {"nn": 0}) - if nn: - self.nndb.update_one({"name": name}, {"$inc": {"downloads": 1}}) - return nn - return None + return self.nndb.find_one({"name": name}, {"nn": 0}) + + def increment_nn_downloads(self, name): + self.nndb.update_one({"name": name}, {"$inc": {"downloads": 1}}) def get_nns( self, user_id, user="", network_name="", master_only=False, limit=0, skip=0 diff --git a/server/fishtest/views.py b/server/fishtest/views.py index 135e231866..f6ecf16e96 100644 --- a/server/fishtest/views.py +++ b/server/fishtest/views.py @@ -729,35 +729,34 @@ def get_sha(branch, repo_url): return "", "" -def get_net(commit_sha, repo_url): - """Get the net from evaluate.h or ucioption.cpp in the repo""" +def get_nets(commit_sha, repo_url): + """Get the nets from evaluate.h or ucioption.cpp in the repo""" api_url = repo_url.replace( "https://github.com", "https://raw.githubusercontent.com" ) try: - net = None + nets = [] + pattern = re.compile("nn-[a-f0-9]{12}.nnue") url1 = api_url + "/" + commit_sha + "/src/evaluate.h" options = requests.get(url1).content.decode("utf-8") for line in options.splitlines(): if "EvalFileDefaultName" in line and "define" in line: - p = re.compile("nn-[a-z0-9]{12}.nnue") - m = p.search(line) + m = pattern.search(line) if m: - net = m.group(0) + nets.append(m.group(0)) - if net: - return net + if nets: + return nets url2 = api_url + "/" + commit_sha + "/src/ucioption.cpp" options = requests.get(url2).content.decode("utf-8") for line in options.splitlines(): if "EvalFile" in line and "Option" in line: - p = re.compile("nn-[a-z0-9]{12}.nnue") - m = p.search(line) + m = pattern.search(line) if m: - net = m.group(0) - return net + nets.append(m.group(0)) + return nets except: raise Exception("Unable to access developer repository: " + api_url) @@ -919,19 +918,23 @@ def strip_message(m): ) data["base_same_as_master"] = master_diff.text == "" - # Test existence of net - new_net = get_net(data["resolved_new"], data["tests_repo"]) - if new_net: - if not request.rundb.get_nn(new_net): - raise Exception( - "The net {}, used by {}, is not " - "known to Fishtest. Please upload it to: " - "{}/upload.".format(new_net, data["new_tag"], request.host_url) + # Store nets info + data["base_nets"] = get_nets(data["resolved_base"], data["tests_repo"]) + data["new_nets"] = get_nets(data["resolved_new"], data["tests_repo"]) + + # Test existence of nets + missing_nets = [] + for net_name in set(data["base_nets"]) | set(data["new_nets"]): + net = request.rundb.get_nn(net_name) + if net is None: + missing_nets.append(net_name) + if missing_nets: + raise Exception( + "Missing net(s). Please upload to: {} the following net(s): {}".format( + request.host_url, + ", ".join(missing_nets), ) - - # Store net info - data["new_net"] = new_net - data["base_net"] = get_net(data["resolved_base"], data["tests_repo"]) + ) # Integer parameters data["threads"] = int(request.POST["threads"]) @@ -1003,25 +1006,32 @@ def del_tasks(run): def update_nets(request, run): run_id = str(run["_id"]) data = run["args"] + base_nets, new_nets, missing_nets = [], [], [] + for net_name in set(data["base_nets"]) | set(data["new_nets"]): + net = request.rundb.get_nn(net_name) + if net is None: + # This should never happen + missing_nets.append(net_name) + else: + if net_name in data["base_nets"]: + base_nets.append(net) + if net_name in data["new_nets"]: + new_nets.append(net) + if missing_nets: + raise Exception( + "Missing net(s). Please upload to {} the following net(s): {}".format( + request.host_url, + ", ".join(missing_nets), + ) + ) + if run["base_same_as_master"]: - base_net = data["base_net"] - if base_net: - net = request.rundb.get_nn(base_net) - if not net: - # Should never happen: - raise Exception( - "The net {}, used by {}, is not " - "known to Fishtest. Please upload it to: " - "{}/upload.".format(base_net, data["base_tag"], request.host_url) - ) + for net in base_nets: if "is_master" not in net: net["is_master"] = True request.rundb.update_nn(net) - new_net = data["new_net"] - if new_net: - net = request.rundb.get_nn(new_net) - if not net: - return + + for net in new_nets: if "first_test" not in net: net["first_test"] = {"id": run_id, "date": datetime.now(timezone.utc)} net["last_test"] = {"id": run_id, "date": datetime.now(timezone.utc)} @@ -1228,7 +1238,10 @@ def tests_approve(request): if run is None: request.session.flash(message, "error") else: - update_nets(request, run) + try: + update_nets(request, run) + except Exception as e: + request.session.flash(str(e), "error") request.actiondb.approve_run(username=username, run=run) cached_flash(request, message) return home(request) @@ -1367,11 +1380,13 @@ def tests_view(request): "new_options", "resolved_new", "new_net", + "new_nets", "base_tag", "base_signature", "base_options", "resolved_base", "base_net", + "base_nets", "sprt", "num_games", "spsa", @@ -1401,6 +1416,9 @@ def tests_view(request): if name == "base_tag" and "msg_base" in run["args"]: value += " (" + run["args"]["msg_base"][:50] + ")" + if name in ("new_nets", "base_nets"): + value = ", ".join(value) + if name == "sprt" and value != "-": value = "elo0: {:.2f} alpha: {:.2f} elo1: {:.2f} beta: {:.2f} state: {} ({})".format( value["elo0"], diff --git a/server/tests/test_api.py b/server/tests/test_api.py index b7e4bfc8d2..3ae264dfe3 100644 --- a/server/tests/test_api.py +++ b/server/tests/test_api.py @@ -39,8 +39,8 @@ def new_run(self, add_tasks=0): msg_new="Super stuff", base_signature="123456", new_signature="654321", - base_net="nn-0000000000a0.nnue", - new_net="nn-0000000000a0.nnue", + base_nets=["nn-0000000000a0.nnue"], + new_nets=["nn-0000000000a0.nnue", "nn-0000000000a1.nnue"], rescheduled_from="653db116cc309ae839563103", base_same_as_master=False, tests_repo="https://google.com", diff --git a/server/tests/test_rundb.py b/server/tests/test_rundb.py index eb5c44c215..01fd676f09 100644 --- a/server/tests/test_rundb.py +++ b/server/tests/test_rundb.py @@ -72,8 +72,8 @@ def test_10_create_run(self): msg_new="Super stuff", base_signature="123456", new_signature="654321", - base_net="nn-0000000000a0.nnue", - new_net="nn-0000000000a0.nnue", + base_nets=["nn-0000000000a0.nnue"], + new_nets=["nn-0000000000a0.nnue", "nn-0000000000a1.nnue"], rescheduled_from="653db116cc309ae839563103", base_same_as_master=False, tests_repo="https://google.com", @@ -113,8 +113,8 @@ def test_10_create_run(self): msg_new="Super stuff", base_signature="123456", new_signature="654321", - base_net="nn-0000000000a0.nnue", - new_net="nn-0000000000a0.nnue", + base_nets=["nn-0000000000a0.nnue"], + new_nets=["nn-0000000000a0.nnue", "nn-0000000000a1.nnue"], rescheduled_from="653db116cc309ae839563103", base_same_as_master=False, tests_repo="https://google.com",