-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmain.py
executable file
·72 lines (60 loc) · 2.29 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#!/usr/bin/env python
from fastapi.testclient import TestClient
import uvicorn
import json
import glob
import tqdm
import random
import argparse
from app import config, message
from app.handler.controller.outer import fastapi_app
def start_mode():
    """Serve the FastAPI application over HTTP.

    Builds the app via ``fastapi_app()`` and hands it to ``uvicorn.run``,
    bound to the host/port configured in ``config.API_HOST`` and
    ``config.API_PORT``.
    """
    uvicorn.run(
        fastapi_app(),
        host=config.API_HOST,
        port=config.API_PORT,
    )
def accuracy_mode():
    """Measure how often the ``judge`` endpoint agrees with labelled samples.

    Steps:
      1. Load every JSON file matching the glob ``config.DATA_FILE_PATH``;
         each file must contain a list of sample dicts, which are
         concatenated. Each sample is read via the ``'dajare'`` and
         ``'is_dajare'`` keys.
      2. Ask the user how many samples to evaluate (empty input = all) and
         draw that many at random.
      3. Query the app in-process through ``TestClient``: ``judge`` for every
         sample, ``reading`` only for samples whose judgement is wrong.
      4. Print the accuracy and write the misjudged samples as UTF-8 JSON to
         ``config.DATA_ERROR_FILE_PATH``.

    Raises:
        ValueError: if the entered sample count is not an integer or is
            negative.
    """
    # load dajare samples
    data: list[dict] = []
    files: list[str] = glob.glob(config.DATA_FILE_PATH)
    for file_name in tqdm.tqdm(files):
        print(message.LOAD_FILE_MSG(file_name))
        # JSON is UTF-8 by spec and the samples are Japanese text; do not
        # rely on the platform default encoding (e.g. cp1252 on Windows).
        with open(file_name, 'r', encoding='utf-8') as f:
            data.extend(json.load(f))

    # specify the number of samples to use
    default_samples: int = len(data)
    input_str = input(message.N_SAMPLES_INPUT_GUIDE(default_samples, len(data))) or default_samples
    n_samples: int = int(input_str)
    if n_samples < 0:
        raise ValueError(f'number of samples must be non-negative, got {n_samples}')
    # random.sample raises ValueError when asked for more items than exist;
    # use the whole dataset instead.
    n_samples = min(n_samples, len(data))
    data = random.sample(data, n_samples)

    # launch API (in-process, no network server needed)
    client = TestClient(fastapi_app())

    # measure accuracy
    error_samples: list[dict] = []
    print(message.MEASURE_ACCURACY_MSG(n_samples))
    for sample in tqdm.tqdm(data):
        dajare = sample['dajare']
        judge_res = client.get('judge', params={'dajare': dajare}).json()
        if judge_res['is_dajare'] == sample['is_dajare']:
            continue
        # The reading is only reported for misjudged samples, so request it
        # lazily rather than for every sample.
        reading_res = client.get('reading', params={'dajare': dajare}).json()
        error_samples.append({
            'dajare': dajare,
            'reading': reading_res['reading'],
            'is_dajare': sample['is_dajare'],
            'judge_result': judge_res['is_dajare'],
            'applied_rule': judge_res['applied_rule'],
        })

    # Guard against ZeroDivisionError when nothing was evaluated (empty
    # dataset or the user asked for 0 samples).
    if n_samples > 0:
        print(message.ACCURACY_MSG((n_samples - len(error_samples)) / n_samples))

    # dump error samples; ensure_ascii=False emits raw non-ASCII characters,
    # so the file must be written as UTF-8 regardless of the locale.
    with open(config.DATA_ERROR_FILE_PATH, 'w', encoding='utf-8') as f:
        json.dump(error_samples, f, ensure_ascii=False, indent=2)
def main(argv=None):
    """Parse command-line options and run the selected mode(s).

    Options:
        -s / --start     serve the API (``start_mode``)
        -a / --accuracy  measure judge accuracy (``accuracy_mode``)

    If both are given, ``start_mode`` runs first and ``accuracy_mode`` runs
    after it returns. If neither is given, the usage/help text is printed.

    :param argv: arguments to parse, excluding the program name; ``None``
        means the process command line, as with
        ``argparse.ArgumentParser.parse_args``.
    """
    # options
    parser: argparse.ArgumentParser = argparse.ArgumentParser()
    parser.add_argument('-s', '--start',
                        help='start app',
                        action='store_true')
    parser.add_argument('-a', '--accuracy',
                        help='measure accuracy',
                        action='store_true')
    args = parser.parse_args(argv)

    # No mode selected: show usage rather than exiting without any output.
    if not (args.start or args.accuracy):
        parser.print_help()
        return

    if args.start:
        start_mode()
    if args.accuracy:
        accuracy_mode()


if __name__ == '__main__':
    main()