From 83db1e0f7d1ffa9621474c73352a362640e8250c Mon Sep 17 00:00:00 2001 From: gyzhou2000 Date: Thu, 4 Jul 2024 16:07:20 +0800 Subject: [PATCH] remove some code --- examples/dna/dna_trainer.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/examples/dna/dna_trainer.py b/examples/dna/dna_trainer.py index ce253dfa..d08167b4 100644 --- a/examples/dna/dna_trainer.py +++ b/examples/dna/dna_trainer.py @@ -115,7 +115,6 @@ def main(args): test_y = tlx.gather(data['y'], data['test_idx']) test_acc = calculate_acc(test_logits, test_y, metrics) print("Test acc: {:.4f}".format(test_acc)) - return test_acc if __name__ == '__main__': @@ -142,15 +141,4 @@ def main(args): else: tlx.set_device("CPU") - # main(args) - - test_accs = [] - for i in range(5): - print(f"Run {i+1}:") - test_acc = main(args) - test_accs.append(test_acc) - - mean_test_acc = np.mean(test_accs) - std_test_acc = np.std(test_accs) - print(f"Mean Test Accuracy: {mean_test_acc:.4f}") - print(f"Standard Deviation of Test Accuracy: {std_test_acc:.4f}") + main(args)