Skip to content

Commit

Permalink
save detection result to json for detic
Browse files Browse the repository at this point in the history
  • Loading branch information
mmitti committed Feb 11, 2025
1 parent db64a8f commit 1f159af
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions object_detection/detic/detic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import cv2
import json
from PIL import Image

import ailia
Expand Down Expand Up @@ -83,6 +84,14 @@
default=800, type=int, # tempolary limit to 800px (original : 1333)
help='The detection width for detic. (default: 800)'
)
parser.add_argument(
'-w', '--write_prediction',
nargs='?',
const='txt',
choices=['txt', 'json'],
type=str,
help='Output results to txt or json file.'
)
args = update_parser(parser)


Expand Down Expand Up @@ -268,7 +277,43 @@ def draw_predictions(img, predictions):

return img

def write_predictions(file_name, predictions, file_type='txt'):
vocabulary = args.vocabulary
boxes = predictions["pred_boxes"].astype(np.int64)
scores = predictions["scores"]
classes = predictions["pred_classes"].tolist()

class_names = (
get_lvis_meta_v1() if vocabulary == 'lvis' else get_in21k_meta_v1()
)["thing_classes"]
# labels = [class_names[i] for i in classes] # onnx runtime
labels = [class_names[int(i)] for i in classes] # ailia always returns float tensor so need to add cast
num_instances = len(boxes)

if file_type == "txt":
with open(file_name, 'w') as f:
for i in range(num_instances):
x0, y0, x1, y1 = boxes[i]
label = labels[i]
score = scores[i]
f.write('%s %f %d %d %d %d\n' % (
label.replace(' ', '_'),
score,
int(x0), int(y0),
int(x1 - x0), int(y1 - y0),
))
else:
results = []
for i in range(num_instances):
x0, y0, x1, y1 = boxes[i]
label = labels[i]
score = scores[i]
results.append({
'category': label, 'score': score,
'x': x0, 'y': y0, 'w': x1 - x0, 'h': y1 - y0
})
with open(file_name, 'w') as f:
json.dump(results, f, indent=2)
# ======================
# Main functions
# ======================
Expand Down Expand Up @@ -413,6 +458,12 @@ def recognize_from_image(net):
savepath = get_savepath(args.savepath, image_path, ext='.png')
logger.info(f'saved at : {savepath}')
cv2.imwrite(savepath, res_img)

# write_prediction
if args.write_prediction is not None:
ext = args.write_prediction
pred_file = "%s.%s" % (savepath.rsplit('.', 1)[0], ext)
write_predictions(pred_file, pred, file_type=ext)

logger.info('Script finished successfully.')

Expand Down

0 comments on commit 1f159af

Please sign in to comment.