From 1ab6a9166f2d92c528459930a606df71fca077db Mon Sep 17 00:00:00 2001 From: leondavi Date: Fri, 10 Nov 2023 23:12:05 +0200 Subject: [PATCH] [TEST] Increase accepted marginal error --- src_py/apiServer/experiment_flow_test.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src_py/apiServer/experiment_flow_test.py b/src_py/apiServer/experiment_flow_test.py index 12074f6e..0c8ffdc5 100644 --- a/src_py/apiServer/experiment_flow_test.py +++ b/src_py/apiServer/experiment_flow_test.py @@ -5,6 +5,8 @@ from logger import * from stats import Stats +TEST_ACCEPTABLE_MARGIN_OF_ERROR = 0.02 + def print_test(in_str : str): PREFIX = "[NERLNET-TEST] " LOG_INFO(f"{PREFIX} {in_str}") @@ -62,7 +64,7 @@ def print_test(in_str : str): exp_stats = Stats(experiment_inst) data = exp_stats.get_loss_min() -print("min loss of each worker") +print_test("min loss of each worker") print(data) conf = exp_stats.get_confusion_matrices() @@ -73,7 +75,9 @@ def print_test(in_str : str): for j in acc_stats[worker].keys(): diff = abs(acc_stats[worker][j]["F1"] - baseline_acc_stats[worker][str(j)]["F1"]) diff_from_baseline.append(diff/baseline_acc_stats[worker][str(j)]["F1"]) -anomaly_detected = not all([x < 0.01 for x in diff_from_baseline]) +anomaly_detected = not all([x < TEST_ACCEPTABLE_MARGIN_OF_ERROR for x in diff_from_baseline]) if anomaly_detected: + print_test("Anomaly failure detected") + print_test(f"diff_from_baseline: {diff_from_baseline}") exit(1)