Skip to content

Commit

Permalink
Code refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
ppigazzini committed Apr 20, 2024
1 parent 9eb0dc9 commit 32af845
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 75 deletions.
52 changes: 18 additions & 34 deletions server/fishtest/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def strip_run(run):
# a deep copy, avoiding copies of a few large lists.
stripped = {}
for k1, v1 in run.items():
if k1 in ["tasks", "bad_tasks"]:
if k1 in ("tasks", "bad_tasks"):
stripped[k1] = []
elif k1 == "args":
stripped[k1] = {}
Expand All @@ -60,9 +60,8 @@ def strip_run(run):
stripped[k1] = copy.deepcopy(v1)

# and some string conversions
stripped["_id"] = str(run["_id"])
stripped["start_time"] = str(run["start_time"])
stripped["last_updated"] = str(run["last_updated"])
for key in ("_id", "start_time", "last_updated"):
stripped[key] = str(run[key])

return stripped

Expand Down Expand Up @@ -211,10 +210,7 @@ def worker_name(self):
def cpu_hours(self):
username = self.get_username()
user = self.request.userdb.user_cache.find_one({"username": username})
if not user:
return -1
else:
return user["cpu_hours"]
return -1 if user is None else user["cpu_hours"]

def message(self):
return self.request_body.get("message", "")
Expand All @@ -238,9 +234,8 @@ def active_runs(self):
active = {}
for run in runs:
# some string conversions
run["_id"] = str(run["_id"])
run["start_time"] = str(run["start_time"])
run["last_updated"] = str(run["last_updated"])
for key in ("_id", "start_time", "last_updated"):
run[key] = str(run[key])
active[str(run["_id"])] = run
return active

Expand Down Expand Up @@ -282,9 +277,8 @@ def finished_runs(self):
finished = {}
for run in runs:
# some string conversions
run["_id"] = str(run["_id"])
run["start_time"] = str(run["start_time"])
run["last_updated"] = str(run["last_updated"])
for key in ("_id", "start_time", "last_updated"):
run[key] = str(run[key])
finished[str(run["_id"])] = run
return finished

Expand Down Expand Up @@ -381,17 +375,17 @@ def calc_elo(self):

is_ptnml = all(
value is not None and value.replace(".", "").replace("-", "").isdigit()
for value in [LL, LD, DDWL, WD, WW]
for value in (LL, LD, DDWL, WD, WW)
)

is_ptnml = is_ptnml and all(int(value) >= 0 for value in [LL, LD, DDWL, WD, WW])
is_ptnml = is_ptnml and all(int(value) >= 0 for value in (LL, LD, DDWL, WD, WW))

is_wdl = not is_ptnml and all(
value is not None and value.replace(".", "").replace("-", "").isdigit()
for value in [W, D, L]
for value in (W, D, L)
)

is_wdl = is_wdl and all(int(value) >= 0 for value in [W, D, L])
is_wdl = is_wdl and all(int(value) >= 0 for value in (W, D, L))

if not is_ptnml and not is_wdl:
self.handle_error(
Expand Down Expand Up @@ -443,7 +437,7 @@ def calc_elo(self):
badEloValues = (
not all(
value.replace(".", "").replace("-", "").isdigit()
for value in [elo0, elo1]
for value in (elo0, elo1)
)
or float(elo1) < float(elo0) + 0.5
or abs(float(elo0)) > 10
Expand Down Expand Up @@ -484,21 +478,11 @@ def request_task(self):

# Strip the run of unneccesary information
run = result["run"]
min_run = {"_id": str(run["_id"]), "args": run["args"], "tasks": []}
if int(str(worker_info["version"]).split(":")[0]) > 64:
task = run["tasks"][result["task_id"]]
min_task = {"num_games": task["num_games"]}
if "stats" in task:
min_task["stats"] = task["stats"]
min_task["start"] = task["start"]
min_run["my_task"] = min_task
else:
for task in run["tasks"]:
min_task = {"num_games": task["num_games"]}
if "stats" in task:
min_task["stats"] = task["stats"]
min_run["tasks"].append(min_task)

task = run["tasks"][result["task_id"]]
min_task = {"num_games": task["num_games"], "start": task["start"]}
if "stats" in task:
min_task["stats"] = task["stats"]
min_run = {"_id": str(run["_id"]), "args": run["args"], "my_tasks": min_task}
result["run"] = min_run
return self.add_time(result)

Expand Down
3 changes: 1 addition & 2 deletions server/fishtest/rundb.py
Original file line number Diff line number Diff line change
Expand Up @@ -1536,8 +1536,7 @@ def generate_spsa(self, run):
spsa = run["args"]["spsa"]

# Generate the next set of tuning parameters
iter_local = spsa["iter"] + 1 # assume at least one completed,
# and avoid division by zero
iter_local = spsa["iter"] + 1 # start from 1 to avoid division by zero
for param in spsa["params"]:
c = param["c"] / iter_local ** spsa["gamma"]
flip = 1 if random.getrandbits(1) else -1
Expand Down
14 changes: 6 additions & 8 deletions server/fishtest/userdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,12 @@ def __init__(self, db):

def find_by_username(self, name):
with self.user_lock:
if name in self.cache:
u = self.cache[name]
if u["time"] > time.time() - 120:
return u["user"]
user = self.cache.get(name)
if user and time.time() < user["time"] + 120:
return user["user"]
user = self.users.find_one({"username": name})
if not user:
return None
self.cache[name] = {"user": user, "time": time.time()}
if user is not None:
self.cache[name] = {"user": user, "time": time.time()}
return user

def find_by_email(self, email):
Expand Down Expand Up @@ -96,7 +94,7 @@ def get_user(self, username):

def get_user_groups(self, username):
user = self.get_user(username)
if user:
if user is not None:
groups = user["groups"]
return groups

Expand Down
37 changes: 17 additions & 20 deletions server/fishtest/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,9 @@ def get_chi2(tasks, exclude_workers=set()):
# So we combine the ww and ll frequencies with the wd and ld frequencies.
wld = [float(p[4] + p[3]), float(p[0] + p[1]), float(p[2])]
if key in users:
for idx in range(len(wld)):
users[key][idx] += wld[idx]
users[key] = [
user_val + wld_val for user_val, wld_val in zip(users[key], wld)
]
else:
users[key] = wld

Expand All @@ -103,11 +104,10 @@ def get_chi2(tasks, exclude_workers=set()):
if grand_total == 0:
return default_results
expected = numpy.outer(row_sums, column_sums) / grand_total
keys = list(users)
filtering_done = True
for idx in range(len(keys)):
if min(expected[idx]) <= 5:
del users[keys[idx]]
for key, expected_value in zip(users.keys(), expected):
if min(expected_value) <= 5:
del users[key]
filtering_done = False

# Now we do the basic chi2 computation.
Expand Down Expand Up @@ -136,11 +136,11 @@ def get_chi2(tasks, exclude_workers=set()):
# in order to be able to deal accurately with very low p-values.
res_z = scipy.stats.norm.isf(scipy.stats.chi2.sf(adj_row_chi2, columns - 1))

for idx in range(len(keys)):
for idx, key in enumerate(users.keys()):
# We cap the standard normal "residuals" at zero since negative values
# do not look very nice and moreover they do not convey any
# information.
users[keys[idx]] = max(0, res_z[idx])
users[key] = max(0, res_z[idx])

# We compute 95% and 99% thresholds using the Bonferroni correction.
# Under the null hypothesis, yellow and red residuals should appear
Expand Down Expand Up @@ -169,22 +169,22 @@ def get_bad_workers(tasks, cached_chi2=None, p=0.001, res=7.0, iters=1):
# If we have an up-to-date result of get_chi2() we can pass
# it as cached_chi2 to avoid needless recomputation.
bad_workers = set()
for _ in range(iters):
if cached_chi2 is None:
chi2 = get_chi2(tasks, exclude_workers=bad_workers)
else:
chi2 = cached_chi2
cached_chi2 = None
for i in range(iters):
chi2 = (
get_chi2(tasks, exclude_workers=bad_workers)
if i > 0 or cached_chi2 is None
else cached_chi2
)
worst_user = {}
residuals = chi2["residual"]
for worker_key in residuals:
if worker_key in bad_workers:
continue
if chi2["p"] < p or residuals[worker_key] > res:
if worst_user == {} or residuals[worker_key] > worst_user["residual"]:
if not worst_user or residuals[worker_key] > worst_user["residual"]:
worst_user["unique_key"] = worker_key
worst_user["residual"] = residuals[worker_key]
if worst_user == {}:
if not worst_user:
break
bad_workers.add(worst_user["unique_key"])

Expand All @@ -194,10 +194,7 @@ def get_bad_workers(tasks, cached_chi2=None, p=0.001, res=7.0, iters=1):
def update_residuals(tasks, cached_chi2=None):
# If we have an up-to-date result of get_chi2() we can pass
# it as cached_chi2 to avoid needless recomputation.
if cached_chi2 is None:
chi2 = get_chi2(tasks)
else:
chi2 = cached_chi2
chi2 = get_chi2(tasks) if cached_chi2 is None else cached_chi2
residuals = chi2["residual"]

for task in tasks:
Expand Down
15 changes: 4 additions & 11 deletions server/fishtest/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1442,10 +1442,7 @@ def tests_view(request):
run = request.rundb.get_run(request.matchdict["id"])
if run is None:
raise exception_response(404)
if "follow" in request.params:
follow = 1
else:
follow = 0
follow = 1 if "follow" in request.params else 0
results = run["results"]
run["results_info"] = format_results(results, run)
run_args = [("id", str(run["_id"]), "")]
Expand Down Expand Up @@ -1508,8 +1505,7 @@ def tests_view(request):
)

if name == "spsa" and value != "-":
iter_local = value["iter"] + 1 # assume at least one completed,
# and avoid division by zero
iter_local = value["iter"] + 1 # start from 1 to avoid division by zero
A = value["A"]
alpha = value["alpha"]
gamma = value["gamma"]
Expand Down Expand Up @@ -1562,17 +1558,14 @@ def tests_view(request):
if task["active"]:
active += 1
cores += task["worker_info"]["concurrency"]
last_updated = task.get(
"last_updated", datetime.min.replace(tzinfo=timezone.utc)
)
task["last_updated"] = last_updated
task.setdefault("last_updated", datetime.min.replace(tzinfo=timezone.utc))

chi2 = get_chi2(run["tasks"])
update_residuals(run["tasks"], cached_chi2=chi2)

try:
show_task = int(request.params.get("show_task", -1))
except:
except ValueError:
show_task = -1
if show_task >= len(run["tasks"]) or show_task < -1:
show_task = -1
Expand Down

0 comments on commit 32af845

Please sign in to comment.