Skip to content

Commit

Permalink
delete onnxruntime suport
Browse files Browse the repository at this point in the history
  • Loading branch information
yuki committed Mar 5, 2024
1 parent f6e1d6c commit c8526e8
Showing 1 changed file with 10 additions and 35 deletions.
45 changes: 10 additions & 35 deletions image_classification/volo/volo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numpy as np

import ailia
import onnxruntime
import volo_labels

# import original modules
Expand Down Expand Up @@ -33,16 +32,11 @@
# Arguemnt Parser Config
# ======================
parser = get_base_parser('Image classification model: VOLO', IMAGE_PATH, None)
parser.add_argument(
'--onnx',
action='store_true',
help='execute onnxruntime version.'
)
parser.add_argument(
'--arch',
default="volo_d1_224",
type=str,
choices=["volo_d1_224" ,"volo_d1_384","volo_d2_224", "volo_d2_384","volo_d3_224","volo_d3_448", "volo_d4_224", "volo_d4_448" ,"volo_d5_224"]
choices=["volo_d1_224" ,"volo_d1_384","volo_d2_224", "volo_d2_384","volo_d3_224","volo_d3_448", "volo_d4_224", "volo_d4_448" ,"volo_d5_224","volo_d5_448","volo_d5_512"]

)
args = update_parser(parser)
Expand All @@ -55,8 +49,8 @@
MODEL_PATH = args.arch + '.onnx.prototxt'
REMOTE_PATH = 'https://storage.googleapis.com/ailia-models/volo/'

IMAGE_HEIGHT = int(args.arch.split("_")[-1])
IMAGE_WIDTH = int(args.arch.split("_")[-1])
IMAGE_HEIGHT = 224
IMAGE_WIDTH = 224
# ======================
# Main functions
# ======================
Expand Down Expand Up @@ -86,11 +80,7 @@ def resize_and_center_crop(image, input_size, crop_pct):

def recognize_from_image():
# net initialize
if args.onnx:
session = onnxruntime.InferenceSession(WEIGHT_PATH)
first_input_name = session.get_inputs()[0].name
else:
net = ailia.Net(MODEL_PATH, WEIGHT_PATH, env_id=args.env_id)
net = ailia.Net(MODEL_PATH, WEIGHT_PATH, env_id=args.env_id)

# input image loop
for image_path in args.input:
Expand All @@ -110,19 +100,13 @@ def recognize_from_image():
logger.info('BENCHMARK mode')
for i in range(args.benchmark_count):
start = int(round(time.time() * 1000))
if args.onnx:
results = session.run([], {first_input_name: input_data})[0]
preds_ailia=np.array(results)
else:
preds_ailia = net.predict(input_data)

preds_ailia = net.predict(input_data)

end = int(round(time.time() * 1000))
logger.info(f'\tailia processing time {end - start} ms')
else:
if args.onnx:
results = session.run([], {first_input_name: input_data})[0]
preds_ailia=np.array(results)
else:
preds_ailia = net.predict(input_data)
preds_ailia = net.predict(input_data)

# postprocessing
print_results(preds_ailia, volo_labels.imagenet_category)
Expand All @@ -131,12 +115,7 @@ def recognize_from_image():

def recognize_from_video():
# net initialize

if args.onnx:
session = onnxruntime.InferenceSession(WEIGHT_PATH)
first_input_name = session.get_inputs()[0].name
else:
net = ailia.Net(MODEL_PATH, WEIGHT_PATH, env_id=args.env_id)
net = ailia.Net(MODEL_PATH, WEIGHT_PATH, env_id=args.env_id)

capture = webcamera_utils.get_capture(args.video)

Expand All @@ -162,11 +141,7 @@ def recognize_from_video():
input_data = input_data.astype(np.float32)

# inference
if args.onnx:
results = session.run([], {first_input_name: input_data})[0]
preds_ailia=np.array(results)
else:
preds_ailia = net.predict(input_data)
preds_ailia = net.predict(input_data)

# postprocessing
plot_results(frame, preds_ailia, volo_labels.imagenet_category)
Expand Down

0 comments on commit c8526e8

Please sign in to comment.