Skip to content

Commit

Permalink
Add dual nets support to the worker
Browse files Browse the repository at this point in the history
  • Loading branch information
ppigazzini committed Dec 13, 2023
1 parent 140699a commit e2f7a5a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 20 deletions.
39 changes: 20 additions & 19 deletions worker/games.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,9 @@ def required_net(engine):
) as p:
for line in iter(p.stdout.readline, ""):
if "EvalFile" in line:
net = line.split(" ")[6].strip()
match = re.search("nn-[a-f0-9]{12}.nnue", line)
if match:
net = match.group(0)
except (OSError, subprocess.SubprocessError) as e:
raise WorkerException(
"Unable to obtain name for required net. Error: {}".format(str(e))
Expand All @@ -258,31 +260,31 @@ def required_net(engine):
return net


def required_net_from_source():
def required_nets_from_source():
"""Parse evaluate.h and ucioption.cpp to find default net"""
net = None

nets = []
pattern = re.compile("nn-[a-f0-9]{12}.nnue")
# NNUE code after binary embedding (Aug 2020)
with open("evaluate.h", "r") as srcfile:
for line in srcfile:
if "EvalFileDefaultName" in line and "define" in line:
p = re.compile("nn-[a-z0-9]{12}.nnue")
m = p.search(line)
if m:
net = m.group(0)
if net:
return net
if "define" in line and (
"EvalFileDefaultNameBig" in line or "EvalFileDefaultNameSmall" in line
):
match = pattern.search(line)
if match:
nets.append(match.group(0))
if nets:
return nets

# NNUE code before binary embedding (Aug 2020)
with open("ucioption.cpp", "r") as srcfile:
for line in srcfile:
if "EvalFile" in line and "Option" in line:
p = re.compile("nn-[a-z0-9]{12}.nnue")
m = p.search(line)
if m:
net = m.group(0)
match = pattern.search(line)
if match:
nets.append(match.group(0))

return net
return nets


def download_net(remote, testing_dir, net):
Expand Down Expand Up @@ -438,7 +440,7 @@ def download_from_github(
try:
blob = download_from_github_raw(item, owner=owner, repo=repo, branch=branch)
except:
print("Downloading {} failed. Trying the github api.".format(item))
print("Downloading {} failed. Trying the GitHub api.".format(item))
try:
blob = download_from_github_api(item, owner=owner, repo=repo, branch=branch)
except:
Expand Down Expand Up @@ -674,8 +676,7 @@ def setup_engine(
prefix = os.path.commonprefix([n.filename for n in file_list])
os.chdir(tmp_dir / prefix / "src")

net = required_net_from_source()
if net:
for net in required_nets_from_source():
print("Build uses default net: ", net)
establish_validated_net(remote, testing_dir, net)
shutil.copyfile(testing_dir / net, net)
Expand Down
2 changes: 1 addition & 1 deletion worker/sri.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"__version": 223, "updater.py": "Mg+pWOgGA0gSo2TuXuuLCWLzwGwH91rsW1W3ixg3jYauHQpRMtNdGnCfuD1GqOhV", "worker.py": "wIWFmhiGottIi/NoTKcDorUn0gFeNZPfMywCnNYOgNEZjvltXZEIeWkm9P+vQAFe", "games.py": "k0HHaT2Jw/RVoWgZOkywOZX2dgWKqznXKFKBpFTQkp9apmlYK3ecX9ilpoUvW9ei"}
{"__version": 223, "updater.py": "Mg+pWOgGA0gSo2TuXuuLCWLzwGwH91rsW1W3ixg3jYauHQpRMtNdGnCfuD1GqOhV", "worker.py": "wIWFmhiGottIi/NoTKcDorUn0gFeNZPfMywCnNYOgNEZjvltXZEIeWkm9P+vQAFe", "games.py": "17cer+9vH0J9Rm2gXe6HMPll6snK+dm7lpfcMp+PzJKE045Wzov+YDBGig1qqLfB"}

0 comments on commit e2f7a5a

Please sign in to comment.