@@ -87,7 +87,7 @@ def run_inference(model_name, models_dir):
87
87
return model_name , p
88
88
89
89
90
- def check_accuracy (model , p ):
90
+ def check_accuracy (model , p , tolerance = 0.001 ):
91
91
#check if the accuracy of the model inference matches with the published numbers
92
92
#Accuracy values picked up from here https://github.com/tensorflow/models/tree/master/research/slim
93
93
accuracy = \
@@ -100,28 +100,29 @@ def check_accuracy(model, p):
100
100
for line in p .splitlines ():
101
101
print (line .decode ())
102
102
if ('eval/Accuracy' .encode () in line ):
103
- top1_accuracy = re .search ("\[(.*?)\]" , line .decode ()).group (1 )
103
+ accuracy = re .split ("eval/Accuracy" , line .decode ())[1 ]
104
+ top1_accuracy = re .search (r'\[(.*)\]' , accuracy ).group (1 )
104
105
#for now we just validate top 1 accuracy, but calculating top5 anyway.
105
106
if ('eval/Recall_5' .encode () in line ):
106
- top5_accuracy = float (
107
- re .search ("\[(.*?)\]" , line . decode () ).group (1 ))
107
+ accuracy = re . split ( "eval/Recall_5" , line . decode ())[ 1 ]
108
+ top5_accuracy = float ( re .search ("\[(.*?)\]" , accuracy ).group (1 ))
108
109
109
110
for i , d in enumerate (data ):
110
111
if (model in data [i ]["model_name" ]):
111
112
# Tolerance check
112
- diff = abs ( float (top1_accuracy ) - float ( data [i ]["accuracy" ]))
113
+ diff = float (data [i ]["accuracy" ]) - float ( top1_accuracy )
113
114
print ('\033 [1m' + '\n Model Accuracy Verification' + '\033 [0m' )
114
- if (diff <= 0.001 ):
115
- print ('\033 [92m' + 'PASS' + '\033 [0m' +
116
- " Functional accuracy " + top1_accuracy +
117
- " is as expected for " + data [i ]["model_name" ])
118
- return True
119
- else :
115
+ if (diff > tolerance ):
120
116
print ('\033 [91m' + 'FAIL' + '\033 [0m' +
121
117
" Functional accuracy " + top1_accuracy +
122
118
" is not as expected for " + data [i ]["model_name" ] +
123
119
"\n Expected accuracy = " + data [i ]["accuracy" ])
124
120
return False
121
+ else :
122
+ print ('\033 [92m' + 'PASS' + '\033 [0m' +
123
+ " Functional accuracy " + top1_accuracy +
124
+ " is as expected for " + data [i ]["model_name" ])
125
+ return True
125
126
126
127
127
128
if __name__ == '__main__' :
@@ -155,5 +156,5 @@ def check_accuracy(model, p):
155
156
try :
156
157
model_name , p = run_inference (args .model_name , models_dir )
157
158
check_accuracy (model_name , p )
158
- except :
159
- print ("Model accuracy verification failed." )
159
+ except Exception as ex :
160
+ print ("Model accuracy verification failed. Exception: %s" % str ( ex ) )
0 commit comments