Skip to content

Commit

Permalink
Add multiple nets support to the server
Browse files Browse the repository at this point in the history
Required by the new SF architecture with multiple nets, see:
official-stockfish/Stockfish#4915
official-stockfish/Stockfish#5068

Also:
- fix the increment of the download of the nets
- improve the net code to deal with a develepment server built from scratch
  • Loading branch information
ppigazzini committed Feb 28, 2024
1 parent 4bee31b commit e0dda9d
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 58 deletions.
3 changes: 3 additions & 0 deletions server/fishtest/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,9 @@ def download_nn(self):
nn = self.request.rundb.get_nn(self.request.matchdict["id"])
if nn is None:
raise exception_response(404)
else:
self.request.rundb.increment_nn_downloads(self.request.matchdict["id"])

return HTTPFound(
"https://data.stockfishchess.org/nn/" + self.request.matchdict["id"]
)
Expand Down
23 changes: 11 additions & 12 deletions server/fishtest/rundb.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
# To make this practical we will eventually put all schemas
# in a separate module "schemas.py".

net_name = regex("nn-[a-z0-9]{12}.nnue", name="net_name")
net_name = regex("nn-[a-f0-9]{12}.nnue", name="net_name")
tc = regex(r"([1-9]\d*/)?\d+(\.\d+)?(\+\d+(\.\d+)?)?", name="tc")
str_int = regex(r"[1-9]\d*", name="str_int")
sha = regex(r"[a-f0-9]{40}", name="sha")
Expand Down Expand Up @@ -111,8 +111,8 @@
"args": {
"base_tag": str,
"new_tag": str,
"base_net": net_name,
"new_net": net_name,
"base_nets": [net_name, ...],
"new_nets": [net_name, ...],
"num_games": int,
"tc": tc,
"new_tc": tc,
Expand Down Expand Up @@ -322,8 +322,8 @@ def new_run(
msg_new="",
base_signature="",
new_signature="",
base_net=None,
new_net=None,
base_nets=None,
new_nets=None,
rescheduled_from=None,
base_same_as_master=None,
start_time=None,
Expand All @@ -344,8 +344,8 @@ def new_run(
run_args = {
"base_tag": base_tag,
"new_tag": new_tag,
"base_net": base_net,
"new_net": new_net,
"base_nets": base_nets,
"new_nets": new_nets,
"num_games": num_games,
"tc": tc,
"new_tc": new_tc,
Expand Down Expand Up @@ -477,11 +477,10 @@ def update_nn(self, net):
self.nndb.update_one({"name": net["name"]}, {"$set": net})

def get_nn(self, name):
nn = self.nndb.find_one({"name": name}, {"nn": 0})
if nn:
self.nndb.update_one({"name": name}, {"$inc": {"downloads": 1}})
return nn
return None
return self.nndb.find_one({"name": name}, {"nn": 0})

def increment_nn_downloads(self, name):
self.nndb.update_one({"name": name}, {"$inc": {"downloads": 1}})

def get_nns(
self, user_id, user="", network_name="", master_only=False, limit=0, skip=0
Expand Down
98 changes: 58 additions & 40 deletions server/fishtest/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,35 +729,34 @@ def get_sha(branch, repo_url):
return "", ""


def get_net(commit_sha, repo_url):
"""Get the net from evaluate.h or ucioption.cpp in the repo"""
def get_nets(commit_sha, repo_url):
"""Get the nets from evaluate.h or ucioption.cpp in the repo"""
api_url = repo_url.replace(
"https://github.com", "https://raw.githubusercontent.com"
)
try:
net = None
nets = []
pattern = re.compile("nn-[a-f0-9]{12}.nnue")

url1 = api_url + "/" + commit_sha + "/src/evaluate.h"
options = requests.get(url1).content.decode("utf-8")
for line in options.splitlines():
if "EvalFileDefaultName" in line and "define" in line:
p = re.compile("nn-[a-z0-9]{12}.nnue")
m = p.search(line)
m = pattern.search(line)
if m:
net = m.group(0)
nets.append(m.group(0))

if net:
return net
if nets:
return nets

url2 = api_url + "/" + commit_sha + "/src/ucioption.cpp"
options = requests.get(url2).content.decode("utf-8")
for line in options.splitlines():
if "EvalFile" in line and "Option" in line:
p = re.compile("nn-[a-z0-9]{12}.nnue")
m = p.search(line)
m = pattern.search(line)
if m:
net = m.group(0)
return net
nets.append(m.group(0))
return nets
except:
raise Exception("Unable to access developer repository: " + api_url)

Expand Down Expand Up @@ -919,19 +918,23 @@ def strip_message(m):
)
data["base_same_as_master"] = master_diff.text == ""

# Test existence of net
new_net = get_net(data["resolved_new"], data["tests_repo"])
if new_net:
if not request.rundb.get_nn(new_net):
raise Exception(
"The net {}, used by {}, is not "
"known to Fishtest. Please upload it to: "
"{}/upload.".format(new_net, data["new_tag"], request.host_url)
# Store nets info
data["base_nets"] = get_nets(data["resolved_base"], data["tests_repo"])
data["new_nets"] = get_nets(data["resolved_new"], data["tests_repo"])

# Test existence of nets
missing_nets = []
for net_name in set(data["base_nets"]) | set(data["new_nets"]):
net = request.rundb.get_nn(net_name)
if net is None:
missing_nets.append(net_name)
if missing_nets:
raise Exception(
"Missing net(s). Please upload to: {} the following net(s): {}".format(
request.host_url,
", ".join(missing_nets),
)

# Store net info
data["new_net"] = new_net
data["base_net"] = get_net(data["resolved_base"], data["tests_repo"])
)

# Integer parameters
data["threads"] = int(request.POST["threads"])
Expand Down Expand Up @@ -1003,25 +1006,32 @@ def del_tasks(run):
def update_nets(request, run):
run_id = str(run["_id"])
data = run["args"]
base_nets, new_nets, missing_nets = [], [], []
for net_name in set(data["base_nets"]) | set(data["new_nets"]):
net = request.rundb.get_nn(net_name)
if net is None:
# This should never happen
missing_nets.append(net_name)
else:
if net_name in data["base_nets"]:
base_nets.append(net)
if net_name in data["new_nets"]:
new_nets.append(net)
if missing_nets:
raise Exception(
"Missing net(s). Please upload to {} the following net(s): {}".format(
request.host_url,
", ".join(missing_nets),
)
)

if run["base_same_as_master"]:
base_net = data["base_net"]
if base_net:
net = request.rundb.get_nn(base_net)
if not net:
# Should never happen:
raise Exception(
"The net {}, used by {}, is not "
"known to Fishtest. Please upload it to: "
"{}/upload.".format(base_net, data["base_tag"], request.host_url)
)
for net in base_nets:
if "is_master" not in net:
net["is_master"] = True
request.rundb.update_nn(net)
new_net = data["new_net"]
if new_net:
net = request.rundb.get_nn(new_net)
if not net:
return

for net in new_nets:
if "first_test" not in net:
net["first_test"] = {"id": run_id, "date": datetime.now(timezone.utc)}
net["last_test"] = {"id": run_id, "date": datetime.now(timezone.utc)}
Expand Down Expand Up @@ -1228,7 +1238,10 @@ def tests_approve(request):
if run is None:
request.session.flash(message, "error")
else:
update_nets(request, run)
try:
update_nets(request, run)
except Exception as e:
request.session.flash(str(e), "error")
request.actiondb.approve_run(username=username, run=run)
cached_flash(request, message)
return home(request)
Expand Down Expand Up @@ -1367,11 +1380,13 @@ def tests_view(request):
"new_options",
"resolved_new",
"new_net",
"new_nets",
"base_tag",
"base_signature",
"base_options",
"resolved_base",
"base_net",
"base_nets",
"sprt",
"num_games",
"spsa",
Expand Down Expand Up @@ -1401,6 +1416,9 @@ def tests_view(request):
if name == "base_tag" and "msg_base" in run["args"]:
value += " (" + run["args"]["msg_base"][:50] + ")"

if name in ("new_nets", "base_nets"):
value = ", ".join(value)

if name == "sprt" and value != "-":
value = "elo0: {:.2f} alpha: {:.2f} elo1: {:.2f} beta: {:.2f} state: {} ({})".format(
value["elo0"],
Expand Down
4 changes: 2 additions & 2 deletions server/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def new_run(self, add_tasks=0):
msg_new="Super stuff",
base_signature="123456",
new_signature="654321",
base_net="nn-0000000000a0.nnue",
new_net="nn-0000000000a0.nnue",
base_nets=["nn-0000000000a0.nnue"],
new_nets=["nn-0000000000a0.nnue", "nn-0000000000a1.nnue"],
rescheduled_from="653db116cc309ae839563103",
base_same_as_master=False,
tests_repo="https://google.com",
Expand Down
8 changes: 4 additions & 4 deletions server/tests/test_rundb.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def test_10_create_run(self):
msg_new="Super stuff",
base_signature="123456",
new_signature="654321",
base_net="nn-0000000000a0.nnue",
new_net="nn-0000000000a0.nnue",
base_nets=["nn-0000000000a0.nnue"],
new_nets=["nn-0000000000a0.nnue", "nn-0000000000a1.nnue"],
rescheduled_from="653db116cc309ae839563103",
base_same_as_master=False,
tests_repo="https://google.com",
Expand Down Expand Up @@ -113,8 +113,8 @@ def test_10_create_run(self):
msg_new="Super stuff",
base_signature="123456",
new_signature="654321",
base_net="nn-0000000000a0.nnue",
new_net="nn-0000000000a0.nnue",
base_nets=["nn-0000000000a0.nnue"],
new_nets=["nn-0000000000a0.nnue", "nn-0000000000a1.nnue"],
rescheduled_from="653db116cc309ae839563103",
base_same_as_master=False,
tests_repo="https://google.com",
Expand Down

0 comments on commit e0dda9d

Please sign in to comment.