diff --git a/server/fishtest/rundb.py b/server/fishtest/rundb.py index dad9a9ee6c..56a1756024 100644 --- a/server/fishtest/rundb.py +++ b/server/fishtest/rundb.py @@ -50,7 +50,7 @@ # To make this practical we will eventually put all schemas # in a separate module "schemas.py". -net_name = regex("nn-[a-z0-9]{12}.nnue", name="net_name") +net_name = regex("nn-[a-f0-9]{12}.nnue", name="net_name") tc = regex(r"([1-9]\d*/)?\d+(\.\d+)?(\+\d+(\.\d+)?)?", name="tc") str_int = regex(r"[1-9]\d*", name="str_int") sha = regex(r"[a-f0-9]{40}", name="sha") @@ -111,8 +111,8 @@ "args": { "base_tag": str, "new_tag": str, - "base_net": net_name, - "new_net": net_name, + "base_nets": [net_name, ...], + "new_nets": [net_name, ...], "num_games": int, "tc": tc, "new_tc": tc, @@ -322,8 +322,8 @@ def new_run( msg_new="", base_signature="", new_signature="", - base_net=None, - new_net=None, + base_nets=None, + new_nets=None, rescheduled_from=None, base_same_as_master=None, start_time=None, @@ -344,8 +344,8 @@ def new_run( run_args = { "base_tag": base_tag, "new_tag": new_tag, - "base_net": base_net, - "new_net": new_net, + "base_nets": base_nets, + "new_nets": new_nets, "num_games": num_games, "tc": tc, "new_tc": new_tc, diff --git a/server/fishtest/views.py b/server/fishtest/views.py index 135e231866..9a615f7e08 100644 --- a/server/fishtest/views.py +++ b/server/fishtest/views.py @@ -729,35 +729,34 @@ def get_sha(branch, repo_url): return "", "" -def get_net(commit_sha, repo_url): - """Get the net from evaluate.h or ucioption.cpp in the repo""" +def get_nets(commit_sha, repo_url): + """Get the nets from evaluate.h or ucioption.cpp in the repo""" api_url = repo_url.replace( "https://github.com", "https://raw.githubusercontent.com" ) try: - net = None + nets = [] + pattern = re.compile("nn-[a-f0-9]{12}.nnue") url1 = api_url + "/" + commit_sha + "/src/evaluate.h" options = requests.get(url1).content.decode("utf-8") for line in options.splitlines(): if "EvalFileDefaultName" in line and "define" in line: - p = re.compile("nn-[a-z0-9]{12}.nnue") - m = p.search(line) + m = pattern.search(line) if m: - net = m.group(0) + nets.append(m.group(0)) - if net: - return net + if nets: + return nets url2 = api_url + "/" + commit_sha + "/src/ucioption.cpp" options = requests.get(url2).content.decode("utf-8") for line in options.splitlines(): if "EvalFile" in line and "Option" in line: - p = re.compile("nn-[a-z0-9]{12}.nnue") - m = p.search(line) + m = pattern.search(line) if m: - net = m.group(0) - return net + nets.append(m.group(0)) + return nets except: raise Exception("Unable to access developer repository: " + api_url) @@ -919,9 +918,9 @@ def strip_message(m): ) data["base_same_as_master"] = master_diff.text == "" - # Test existence of net - new_net = get_net(data["resolved_new"], data["tests_repo"]) - if new_net: + # Test existence of nets + new_nets = get_nets(data["resolved_new"], data["tests_repo"]) + for new_net in new_nets: if not request.rundb.get_nn(new_net): raise Exception( "The net {}, used by {}, is not " @@ -929,9 +928,9 @@ def strip_message(m): "{}/upload.".format(new_net, data["new_tag"], request.host_url) ) - # Store net info - data["new_net"] = new_net - data["base_net"] = get_net(data["resolved_base"], data["tests_repo"]) + # Store nets info + data["new_nets"] = new_nets + data["base_nets"] = get_nets(data["resolved_base"], data["tests_repo"]) # Integer parameters data["threads"] = int(request.POST["threads"]) @@ -1004,8 +1003,8 @@ def update_nets(request, run): run_id = str(run["_id"]) data = run["args"] if run["base_same_as_master"]: - base_net = data["base_net"] - if base_net: + base_nets = data["base_nets"] + for base_net in base_nets: net = request.rundb.get_nn(base_net) if not net: # Should never happen: @@ -1017,8 +1016,8 @@ def update_nets(request, run): if "is_master" not in net: net["is_master"] = True request.rundb.update_nn(net) - new_net = data["new_net"] - if new_net: + new_nets = data["new_nets"] + for new_net in new_nets: net = request.rundb.get_nn(new_net) if not net: return @@ -1366,12 +1365,12 @@ def tests_view(request): "new_signature", "new_options", "resolved_new", - "new_net", + "new_nets", "base_tag", "base_signature", "base_options", "resolved_base", - "base_net", + "base_nets", "sprt", "num_games", "spsa", diff --git a/server/tests/test_api.py b/server/tests/test_api.py index b7e4bfc8d2..3ae264dfe3 100644 --- a/server/tests/test_api.py +++ b/server/tests/test_api.py @@ -39,8 +39,8 @@ def new_run(self, add_tasks=0): msg_new="Super stuff", base_signature="123456", new_signature="654321", - base_net="nn-0000000000a0.nnue", - new_net="nn-0000000000a0.nnue", + base_nets=["nn-0000000000a0.nnue"], + new_nets=["nn-0000000000a0.nnue", "nn-0000000000a1.nnue"], rescheduled_from="653db116cc309ae839563103", base_same_as_master=False, tests_repo="https://google.com", diff --git a/server/tests/test_rundb.py b/server/tests/test_rundb.py index eb5c44c215..92bddee6d5 100644 --- a/server/tests/test_rundb.py +++ b/server/tests/test_rundb.py @@ -72,8 +72,8 @@ def test_10_create_run(self): msg_new="Super stuff", base_signature="123456", new_signature="654321", - base_net="nn-0000000000a0.nnue", - new_net="nn-0000000000a0.nnue", + base_nets="nn-0000000000a0.nnue", + new_nets=["nn-0000000000a0.nnue", "nn-0000000000a1.nnue"], rescheduled_from="653db116cc309ae839563103", base_same_as_master=False, tests_repo="https://google.com",