Skip to content

Commit

Permalink
Added JSON log to ONNX inference scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
sovit-123 committed May 22, 2024
1 parent 8c2fe31 commit f97a888
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 4 deletions.
1 change: 1 addition & 0 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def main(args):
# Save JSON log file.
if args['log_json']:
log_json.save(os.path.join(OUT_DIR, 'log.json'))

# Calculate and print the average FPS.
avg_fps = total_fps / frame_count
print(f"Average FPS: {avg_fps:.3f}")
Expand Down
11 changes: 9 additions & 2 deletions onnx_inference_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from utils.annotations import (
inference_annotations, convert_detections
)
# from utils.logging import log_to_json
from utils.logging import LogJSON

def collect_all_images(dir_test):
"""
Expand Down Expand Up @@ -136,6 +136,8 @@ def main(args):
# score below this will be discarded.
detection_threshold = args['threshold']

if args['log_json']:
log_json = LogJSON(os.path.join(OUT_DIR, 'log.json'))

# To count the total number of frames iterated through.
frame_count = 0
Expand Down Expand Up @@ -178,7 +180,7 @@ def main(args):

# Log to JSON?
if args['log_json']:
log_to_json(orig_image, os.path.join(OUT_DIR, 'log.json'), outputs)
log_json.update(orig_image, image_name, outputs[0], CLASSES)

# Carry further only if there are detected boxes.
if len(outputs[0]['boxes']) != 0:
Expand Down Expand Up @@ -208,6 +210,11 @@ def main(args):

print('TEST PREDICTIONS COMPLETE')
cv2.destroyAllWindows()

# Save JSON log file.
if args['log_json']:
log_json.save(os.path.join(OUT_DIR, 'log.json'))

# Calculate and print the average FPS.
avg_fps = total_fps / frame_count
print(f"Average FPS: {avg_fps:.3f}")
Expand Down
11 changes: 9 additions & 2 deletions onnx_inference_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)
from utils.transforms import infer_transforms, resize
from deep_sort_realtime.deepsort_tracker import DeepSort
# from utils.logging import log_to_json
from utils.logging import LogJSON

def read_return_video_data(video_path):
cap = cv2.VideoCapture(video_path)
Expand Down Expand Up @@ -148,6 +148,9 @@ def main(args):
else:
RESIZE_TO = frame_width

if args['log_json']:
log_json = LogJSON(os.path.join(OUT_DIR, 'log.json'))

frame_count = 0 # To count total frames.
total_fps = 0 # To get the final frames per second.

Expand Down Expand Up @@ -185,7 +188,7 @@ def main(args):

# Log to JSON?
if args['log_json']:
log_to_json(frame, os.path.join(OUT_DIR, 'log.json'), outputs)
log_json.update(frame, save_name, outputs[0], CLASSES)

# Carry further only if there are detected boxes.
if len(outputs[0]['boxes']) != 0:
Expand Down Expand Up @@ -234,6 +237,10 @@ def main(args):
# Close all frames and video windows.
cv2.destroyAllWindows()

# Save JSON log file.
if args['log_json']:
log_json.save(os.path.join(OUT_DIR, 'log.json'))

# Calculate and print the average FPS.
avg_fps = total_fps / frame_count
print(f"Average FPS: {avg_fps:.3f}")
Expand Down

0 comments on commit f97a888

Please sign in to comment.