Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit fe3cb21

Browse files
sindhu-nervana and avijit-nervana
authored and committed
update regex, tolerance to verify accuracy (#528)
* update regex, tolerance to verify accuracy * chris's comments
1 parent b4cb842 commit fe3cb21

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

diagnostics/model_accuracy/verify_inference_model.py

+14-13
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def run_inference(model_name, models_dir):
8787
return model_name, p
8888

8989

90-
def check_accuracy(model, p):
90+
def check_accuracy(model, p, tolerance=0.001):
9191
#check if the accuracy of the model inference matches with the published numbers
9292
#Accuracy values picked up from here https://github.com/tensorflow/models/tree/master/research/slim
9393
accuracy = \
@@ -100,28 +100,29 @@ def check_accuracy(model, p):
100100
for line in p.splitlines():
101101
print(line.decode())
102102
if ('eval/Accuracy'.encode() in line):
103-
top1_accuracy = re.search("\[(.*?)\]", line.decode()).group(1)
103+
accuracy = re.split("eval/Accuracy", line.decode())[1]
104+
top1_accuracy = re.search(r'\[(.*)\]', accuracy).group(1)
104105
#for now we just validate top 1 accuracy, but calculating top5 anyway.
105106
if ('eval/Recall_5'.encode() in line):
106-
top5_accuracy = float(
107-
re.search("\[(.*?)\]", line.decode()).group(1))
107+
accuracy = re.split("eval/Recall_5", line.decode())[1]
108+
top5_accuracy = float(re.search("\[(.*?)\]", accuracy).group(1))
108109

109110
for i, d in enumerate(data):
110111
if (model in data[i]["model_name"]):
111112
# Tolerance check
112-
diff = abs(float(top1_accuracy) - float(data[i]["accuracy"]))
113+
diff = float(data[i]["accuracy"]) - float(top1_accuracy)
113114
print('\033[1m' + '\nModel Accuracy Verification' + '\033[0m')
114-
if (diff <= 0.001):
115-
print('\033[92m' + 'PASS' + '\033[0m' +
116-
" Functional accuracy " + top1_accuracy +
117-
" is as expected for " + data[i]["model_name"])
118-
return True
119-
else:
115+
if (diff > tolerance):
120116
print('\033[91m' + 'FAIL' + '\033[0m' +
121117
" Functional accuracy " + top1_accuracy +
122118
" is not as expected for " + data[i]["model_name"] +
123119
"\nExpected accuracy = " + data[i]["accuracy"])
124120
return False
121+
else:
122+
print('\033[92m' + 'PASS' + '\033[0m' +
123+
" Functional accuracy " + top1_accuracy +
124+
" is as expected for " + data[i]["model_name"])
125+
return True
125126

126127

127128
if __name__ == '__main__':
@@ -155,5 +156,5 @@ def check_accuracy(model, p):
155156
try:
156157
model_name, p = run_inference(args.model_name, models_dir)
157158
check_accuracy(model_name, p)
158-
except:
159-
print("Model accuracy verification failed.")
159+
except Exception as ex:
160+
print("Model accuracy verification failed. Exception: %s" % str(ex))

0 commit comments

Comments (0)