Skip to content

Commit

Permalink
all tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
gykovacs committed Jul 10, 2024
1 parent 4a219cf commit 82504cb
Show file tree
Hide file tree
Showing 2 changed files with 298 additions and 215 deletions.
44 changes: 19 additions & 25 deletions mlscorecheck/auc/_auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@
'auc_from_sens_spec',
'acc_from_auc',
'auc_from_sens_spec_kfold',
'generate_sens_spec_acc_problem',
'generate_kfold_sens_spec_acc_problem',
#'generate_sens_spec_acc_problem',
#'generate_kfold_sens_spec_acc_problem',
'generate_average',
'generate_kfold_sens_spec_fix_problem',
'generate_kfold_acc_fix_problem',
'generate_kfold_auc_fix_problem',
'R'
]

Expand All @@ -32,15 +30,24 @@ def generate_average(avg_value, n_items, lower_bound=None, random_state=None):

indices = list(range(n_items))

for _ in range(n_items*2):
for _ in range(n_items*10):
a, b = random_state.choice(indices, 2, replace=False)
dist = min(values[a], 1 - values[a], values[b], 1 - values[b])
d = random_state.random() * dist
values[a] += d
values[b] -= d
if random_state.randint(2) == 0:
dist = min(values[a], 1 - values[a], values[b], 1 - values[b])
d = random_state.random() * dist

if lower_bound is not None and values[b] - d < lower_bound:
d = values[b] - lower_bound
values[a] += d
values[b] -= d
else:
mean = (values[a] + values[b]) / 2
values[a] = (values[a] + mean) / 2
values[b] = (values[b] + mean) / 2

return values.astype(float)

"""
def generate_sens_spec_acc_problem(
*,
n_swaps : float = 0.3,
Expand Down Expand Up @@ -123,7 +130,9 @@ def generate_sens_spec_acc_problem(
}
return result
"""

"""
def generate_kfold_sens_spec_acc_problem(
n_folds : int | None = None,
n_swaps : float = 0.3,
Expand Down Expand Up @@ -168,6 +177,7 @@ def generate_kfold_sens_spec_acc_problem(
}
return results
"""

def generate_kfold_sens_spec_fix_problem(
*,
Expand All @@ -181,22 +191,6 @@ def generate_kfold_sens_spec_fix_problem(
return {'sens': generate_average(sens, k, sens_lower_bound, random_state),
'spec': generate_average(spec, k, spec_lower_bound, random_state)}

def generate_kfold_acc_fix_problem(
acc,
k,
acc_lower_bound = None,
random_state = None
):
return {'acc': generate_average(acc, k, acc_lower_bound, random_state)}

def generate_kfold_auc_fix_problem(
auc,
k,
auc_lower_bound = None,
random_state = None
):
return {'auc': generate_average(auc, k, auc_lower_bound, random_state)}

def prepare_intervals_for_auc_estimation(
scores: dict,
eps: float,
Expand Down
Loading

0 comments on commit 82504cb

Please sign in to comment.