Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multiple nets support to the server #1902

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions server/fishtest/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,9 @@ def download_nn(self):
nn = self.request.rundb.get_nn(self.request.matchdict["id"])
if nn is None:
raise exception_response(404)
else:
self.request.rundb.increment_nn_downloads(self.request.matchdict["id"])

return HTTPFound(
"https://data.stockfishchess.org/nn/" + self.request.matchdict["id"]
)
Expand Down
23 changes: 11 additions & 12 deletions server/fishtest/rundb.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
# To make this practical we will eventually put all schemas
# in a separate module "schemas.py".

net_name = regex("nn-[a-z0-9]{12}.nnue", name="net_name")
net_name = regex("nn-[a-f0-9]{12}.nnue", name="net_name")
tc = regex(r"([1-9]\d*/)?\d+(\.\d+)?(\+\d+(\.\d+)?)?", name="tc")
str_int = regex(r"[1-9]\d*", name="str_int")
sha = regex(r"[a-f0-9]{40}", name="sha")
Expand Down Expand Up @@ -111,8 +111,8 @@
"args": {
"base_tag": str,
"new_tag": str,
"base_net": net_name,
"new_net": net_name,
"base_nets": [net_name, ...],
"new_nets": [net_name, ...],
"num_games": int,
"tc": tc,
"new_tc": tc,
Expand Down Expand Up @@ -322,8 +322,8 @@ def new_run(
msg_new="",
base_signature="",
new_signature="",
base_net=None,
new_net=None,
base_nets=None,
new_nets=None,
rescheduled_from=None,
base_same_as_master=None,
start_time=None,
Expand All @@ -344,8 +344,8 @@ def new_run(
run_args = {
"base_tag": base_tag,
"new_tag": new_tag,
"base_net": base_net,
"new_net": new_net,
"base_nets": base_nets,
"new_nets": new_nets,
"num_games": num_games,
"tc": tc,
"new_tc": new_tc,
Expand Down Expand Up @@ -477,11 +477,10 @@ def update_nn(self, net):
self.nndb.update_one({"name": net["name"]}, {"$set": net})

def get_nn(self, name):
nn = self.nndb.find_one({"name": name}, {"nn": 0})
if nn:
self.nndb.update_one({"name": name}, {"$inc": {"downloads": 1}})
return nn
return None
return self.nndb.find_one({"name": name}, {"nn": 0})

def increment_nn_downloads(self, name):
self.nndb.update_one({"name": name}, {"$inc": {"downloads": 1}})

def get_nns(
self, user_id, user="", network_name="", master_only=False, limit=0, skip=0
Expand Down
98 changes: 58 additions & 40 deletions server/fishtest/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,35 +729,34 @@ def get_sha(branch, repo_url):
return "", ""


def get_net(commit_sha, repo_url):
"""Get the net from evaluate.h or ucioption.cpp in the repo"""
def get_nets(commit_sha, repo_url):
"""Get the nets from evaluate.h or ucioption.cpp in the repo"""
api_url = repo_url.replace(
"https://github.com", "https://raw.githubusercontent.com"
)
try:
net = None
nets = []
pattern = re.compile("nn-[a-f0-9]{12}.nnue")

url1 = api_url + "/" + commit_sha + "/src/evaluate.h"
options = requests.get(url1).content.decode("utf-8")
for line in options.splitlines():
if "EvalFileDefaultName" in line and "define" in line:
p = re.compile("nn-[a-z0-9]{12}.nnue")
m = p.search(line)
m = pattern.search(line)
if m:
net = m.group(0)
nets.append(m.group(0))

if net:
return net
if nets:
return nets

url2 = api_url + "/" + commit_sha + "/src/ucioption.cpp"
options = requests.get(url2).content.decode("utf-8")
for line in options.splitlines():
if "EvalFile" in line and "Option" in line:
p = re.compile("nn-[a-z0-9]{12}.nnue")
m = p.search(line)
m = pattern.search(line)
if m:
net = m.group(0)
return net
nets.append(m.group(0))
return nets
except:
raise Exception("Unable to access developer repository: " + api_url)

Expand Down Expand Up @@ -919,19 +918,23 @@ def strip_message(m):
)
data["base_same_as_master"] = master_diff.text == ""

# Test existence of net
new_net = get_net(data["resolved_new"], data["tests_repo"])
if new_net:
if not request.rundb.get_nn(new_net):
raise Exception(
"The net {}, used by {}, is not "
"known to Fishtest. Please upload it to: "
"{}/upload.".format(new_net, data["new_tag"], request.host_url)
# Store nets info
data["base_nets"] = get_nets(data["resolved_base"], data["tests_repo"])
data["new_nets"] = get_nets(data["resolved_new"], data["tests_repo"])

# Test existence of nets
missing_nets = []
for net_name in set(data["base_nets"]) | set(data["new_nets"]):
net = request.rundb.get_nn(net_name)
if net is None:
missing_nets.append(net_name)
if missing_nets:
raise Exception(
"Missing net(s). Please upload to: {} the following net(s): {}".format(
request.host_url,
", ".join(missing_nets),
)

# Store net info
data["new_net"] = new_net
data["base_net"] = get_net(data["resolved_base"], data["tests_repo"])
)

# Integer parameters
data["threads"] = int(request.POST["threads"])
Expand Down Expand Up @@ -1003,25 +1006,32 @@ def del_tasks(run):
def update_nets(request, run):
run_id = str(run["_id"])
data = run["args"]
base_nets, new_nets, missing_nets = [], [], []
for net_name in set(data["base_nets"]) | set(data["new_nets"]):
net = request.rundb.get_nn(net_name)
if net is None:
# This should never happen
missing_nets.append(net_name)
else:
if net_name in data["base_nets"]:
base_nets.append(net)
if net_name in data["new_nets"]:
new_nets.append(net)
if missing_nets:
raise Exception(
"Missing net(s). Please upload to {} the following net(s): {}".format(
request.host_url,
", ".join(missing_nets),
)
)

if run["base_same_as_master"]:
base_net = data["base_net"]
if base_net:
net = request.rundb.get_nn(base_net)
if not net:
# Should never happen:
raise Exception(
"The net {}, used by {}, is not "
"known to Fishtest. Please upload it to: "
"{}/upload.".format(base_net, data["base_tag"], request.host_url)
)
for net in base_nets:
if "is_master" not in net:
net["is_master"] = True
request.rundb.update_nn(net)
new_net = data["new_net"]
if new_net:
net = request.rundb.get_nn(new_net)
if not net:
return

for net in new_nets:
if "first_test" not in net:
net["first_test"] = {"id": run_id, "date": datetime.now(timezone.utc)}
net["last_test"] = {"id": run_id, "date": datetime.now(timezone.utc)}
Expand Down Expand Up @@ -1228,7 +1238,10 @@ def tests_approve(request):
if run is None:
request.session.flash(message, "error")
else:
update_nets(request, run)
try:
update_nets(request, run)
except Exception as e:
request.session.flash(str(e), "error")
request.actiondb.approve_run(username=username, run=run)
cached_flash(request, message)
return home(request)
Expand Down Expand Up @@ -1367,11 +1380,13 @@ def tests_view(request):
"new_options",
"resolved_new",
"new_net",
"new_nets",
"base_tag",
"base_signature",
"base_options",
"resolved_base",
"base_net",
"base_nets",
"sprt",
"num_games",
"spsa",
Expand Down Expand Up @@ -1401,6 +1416,9 @@ def tests_view(request):
if name == "base_tag" and "msg_base" in run["args"]:
value += " (" + run["args"]["msg_base"][:50] + ")"

if name in ("new_nets", "base_nets"):
value = ", ".join(value)

if name == "sprt" and value != "-":
value = "elo0: {:.2f} alpha: {:.2f} elo1: {:.2f} beta: {:.2f} state: {} ({})".format(
value["elo0"],
Expand Down
4 changes: 2 additions & 2 deletions server/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def new_run(self, add_tasks=0):
msg_new="Super stuff",
base_signature="123456",
new_signature="654321",
base_net="nn-0000000000a0.nnue",
new_net="nn-0000000000a0.nnue",
base_nets=["nn-0000000000a0.nnue"],
new_nets=["nn-0000000000a0.nnue", "nn-0000000000a1.nnue"],
rescheduled_from="653db116cc309ae839563103",
base_same_as_master=False,
tests_repo="https://google.com",
Expand Down
8 changes: 4 additions & 4 deletions server/tests/test_rundb.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def test_10_create_run(self):
msg_new="Super stuff",
base_signature="123456",
new_signature="654321",
base_net="nn-0000000000a0.nnue",
new_net="nn-0000000000a0.nnue",
base_nets=["nn-0000000000a0.nnue"],
new_nets=["nn-0000000000a0.nnue", "nn-0000000000a1.nnue"],
rescheduled_from="653db116cc309ae839563103",
base_same_as_master=False,
tests_repo="https://google.com",
Expand Down Expand Up @@ -113,8 +113,8 @@ def test_10_create_run(self):
msg_new="Super stuff",
base_signature="123456",
new_signature="654321",
base_net="nn-0000000000a0.nnue",
new_net="nn-0000000000a0.nnue",
base_nets=["nn-0000000000a0.nnue"],
new_nets=["nn-0000000000a0.nnue", "nn-0000000000a1.nnue"],
rescheduled_from="653db116cc309ae839563103",
base_same_as_master=False,
tests_repo="https://google.com",
Expand Down
Loading