forked from Tianxiaomo/pytorch-YOLOv4
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4d298fa
commit da35f12
Showing
5 changed files
with
280 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import sys | ||
import onnx | ||
import os | ||
import argparse | ||
import numpy as np | ||
import cv2 | ||
import onnxruntime | ||
|
||
from tool.utils import * | ||
from tool.darknet2onnx import * | ||
|
||
|
||
def main(cfg_file, weight_file, image_path, batch_size): | ||
|
||
# Transform to onnx as specified batch size | ||
transform_to_onnx(cfg_file, weight_file, batch_size) | ||
# Transform to onnx for demo | ||
onnx_path_demo = transform_to_onnx(cfg_file, weight_file, 1) | ||
|
||
session = onnxruntime.InferenceSession(onnx_path_demo) | ||
# session = onnx.load(onnx_path) | ||
print("The model expects input shape: ", session.get_inputs()[0].shape) | ||
|
||
image_src = cv2.imread(image_path) | ||
detect(session, image_src) | ||
|
||
|
||
|
||
def detect(session, image_src): | ||
IN_IMAGE_H = session.get_inputs()[0].shape[2] | ||
IN_IMAGE_W = session.get_inputs()[0].shape[3] | ||
|
||
# Input | ||
resized = cv2.resize(image_src, (IN_IMAGE_W, IN_IMAGE_H), interpolation=cv2.INTER_LINEAR) | ||
img_in = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB) | ||
img_in = np.transpose(img_in, (2, 0, 1)).astype(np.float32) | ||
img_in = np.expand_dims(img_in, axis=0) | ||
img_in /= 255.0 | ||
print("Shape of the network input: ", img_in.shape) | ||
|
||
# Compute | ||
input_name = session.get_inputs()[0].name | ||
# output, output_exist = session.run(['decoder.output_conv', 'lane_exist.linear2'], {"input.1": image_np}) | ||
|
||
# print(img_in) | ||
|
||
outputs = session.run(None, {input_name: img_in}) | ||
|
||
''' | ||
print(len(outputs)) | ||
print(outputs[0].shape) | ||
print(outputs[1].shape) | ||
print(outputs[2].shape) | ||
print(outputs[3].shape) | ||
print(outputs[4].shape) | ||
print(outputs[5].shape) | ||
''' | ||
|
||
outputs = [ | ||
[outputs[0],outputs[1]], | ||
[outputs[2],outputs[3]], | ||
[outputs[4],outputs[5]] | ||
] | ||
|
||
# print(outputs[2]) | ||
|
||
num_classes = 80 | ||
boxes = post_processing(img_in, 0.5, num_classes, 0.4, outputs) | ||
|
||
if num_classes == 20: | ||
namesfile = 'data/voc.names' | ||
elif num_classes == 80: | ||
namesfile = 'data/coco.names' | ||
else: | ||
namesfile = 'data/names' | ||
|
||
class_names = load_class_names(namesfile) | ||
plot_boxes_cv2(image_src, boxes, savename='predictions_onnx.jpg', class_names=class_names) | ||
|
||
|
||
|
||
if __name__ == '__main__': | ||
print("Converting to onnx and running demo ...") | ||
if len(sys.argv) == 5: | ||
cfg_file = sys.argv[1] | ||
weight_file = sys.argv[2] | ||
image_path = sys.argv[3] | ||
batch_size = int(sys.argv[4]) | ||
main(cfg_file, weight_file, image_path, batch_size) | ||
else: | ||
print('Please run this way:\n') | ||
print(' python demo_onnx.py <cfgFile> <weightFile> <imageFile> <batchSize>') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import sys | ||
import onnx | ||
import os | ||
import argparse | ||
import numpy as np | ||
import cv2 | ||
import onnxruntime | ||
import torch | ||
|
||
from tool.utils import * | ||
from models import Yolov4 | ||
from demo_darknet2onnx import detect | ||
|
||
|
||
def transform_to_onnx(weight_file, batch_size, n_classes, IN_IMAGE_H, IN_IMAGE_W): | ||
|
||
model = Yolov4(n_classes=n_classes, inference=True) | ||
|
||
pretrained_dict = torch.load(weight_file, map_location=torch.device('cuda')) | ||
model.load_state_dict(pretrained_dict) | ||
|
||
x = torch.randn((batch_size, 3, IN_IMAGE_H, IN_IMAGE_W), requires_grad=True) # .cuda() | ||
|
||
onnx_file_name = "yolov4_{}_3_{}_{}.onnx".format(batch_size, IN_IMAGE_H, IN_IMAGE_W) | ||
|
||
# Export the model | ||
print('Export the onnx model ...') | ||
torch.onnx.export(model, | ||
x, | ||
onnx_file_name, | ||
export_params=True, | ||
opset_version=11, | ||
do_constant_folding=True, | ||
# input_names=['input'], output_names=['output_1', 'output_2', 'output_3'], | ||
dynamic_axes=None) | ||
|
||
print('Onnx model exporting done') | ||
return onnx_file_name | ||
|
||
|
||
|
||
def main(weight_file, image_path, batch_size, n_classes, IN_IMAGE_H, IN_IMAGE_W): | ||
|
||
# Transform to onnx as specified batch size | ||
transform_to_onnx(weight_file, batch_size, n_classes, IN_IMAGE_H, IN_IMAGE_W) | ||
# Transform to onnx for demo | ||
onnx_path_demo = transform_to_onnx(weight_file, 1, n_classes, IN_IMAGE_H, IN_IMAGE_W) | ||
|
||
session = onnxruntime.InferenceSession(onnx_path_demo) | ||
# session = onnx.load(onnx_path) | ||
print("The model expects input shape: ", session.get_inputs()[0].shape) | ||
|
||
image_src = cv2.imread(image_path) | ||
detect(session, image_src) | ||
|
||
|
||
|
||
if __name__ == '__main__': | ||
print("Converting to onnx and running demo ...") | ||
if len(sys.argv) == 7: | ||
|
||
weight_file = sys.argv[1] | ||
image_path = sys.argv[2] | ||
batch_size = int(sys.argv[3]) | ||
n_classes = int(sys.argv[4]) | ||
IN_IMAGE_H = int(sys.argv[5]) | ||
IN_IMAGE_W = int(sys.argv[6]) | ||
|
||
main(weight_file, image_path, batch_size, n_classes, IN_IMAGE_H, IN_IMAGE_W) | ||
else: | ||
print('Please run this way:\n') | ||
print(' python demo_onnx.py <weight_file> <image_path> <batch_size> <n_classes> <IN_IMAGE_H> <IN_IMAGE_W>') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.