From 81fbd54c9d2677d66caf65f1479d67a50e19899d Mon Sep 17 00:00:00 2001 From: yeliang2258 <30516196+yeliang2258@users.noreply.github.com> Date: Fri, 14 Apr 2023 10:31:36 +0800 Subject: [PATCH] [Other] Add tests for TIMVX (#1605) * add tests for timvx * add mobilenetv1 test * update code * fix log info * update log * fix test --------- Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com> --- cmake/toolchain.cmake | 2 +- tests/CMakeLists.txt | 3 + tests/timvx/CMakeLists.txt | 46 +++++++ tests/timvx/common.h | 242 +++++++++++++++++++++++++++++++++ tests/timvx/download_models.py | 165 ++++++++++++++++++++++ tests/timvx/models_url.txt | 8 ++ tests/timvx/prepare.sh | 4 + tests/timvx/requirements.txt | 1 + tests/timvx/run_test.sh | 9 ++ tests/timvx/test_clas.cc | 68 +++++++++ tests/timvx/test_ppliteseg.cc | 76 +++++++++++ tests/timvx/test_ppyoloe.cc | 68 +++++++++ tests/timvx/test_yolov5.cc | 71 ++++++++++ 13 files changed, 762 insertions(+), 1 deletion(-) mode change 100644 => 100755 tests/CMakeLists.txt create mode 100755 tests/timvx/CMakeLists.txt create mode 100755 tests/timvx/common.h create mode 100755 tests/timvx/download_models.py create mode 100755 tests/timvx/models_url.txt create mode 100755 tests/timvx/prepare.sh create mode 100755 tests/timvx/requirements.txt create mode 100755 tests/timvx/run_test.sh create mode 100755 tests/timvx/test_clas.cc create mode 100755 tests/timvx/test_ppliteseg.cc create mode 100755 tests/timvx/test_ppyoloe.cc create mode 100755 tests/timvx/test_yolov5.cc diff --git a/cmake/toolchain.cmake b/cmake/toolchain.cmake index d06416b406..f05157fdf0 100755 --- a/cmake/toolchain.cmake +++ b/cmake/toolchain.cmake @@ -35,7 +35,7 @@ if (DEFINED TARGET_ABI) if(WITH_TIMVX) set(PADDLELITE_URL "https://bj.bcebos.com/fastdeploy/third_libs/lite-linux-aarch64-timvx-20230316.tgz") else() - set(PADDLELITE_URL "https://bj.bcebos.com/fastdeploy/third_libs/lite-linux-arm64-20221209.tgz") + set(PADDLELITE_URL "https://bj.bcebos.com/fastdeploy/third_libs/lite-linux-arm64-20230316.tgz") endif() set(THIRD_PARTY_PATH ${CMAKE_CURRENT_BINARY_DIR}/third_libs) set(OpenCV_DIR ${THIRD_PARTY_PATH}/install/opencv/lib/cmake/opencv4) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt old mode 100644 new mode 100755 index 74b7e8e52b..5a787c5021 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -72,6 +72,9 @@ if(WITH_TESTING) message(STATUS "") message(STATUS "*************FastDeploy Unittest Summary**********") file(GLOB_RECURSE ALL_TEST_SRCS ${PROJECT_SOURCE_DIR}/tests/*/test_*.cc) + + file(GLOB_RECURSE TIMVX_SRCS ${PROJECT_SOURCE_DIR}/tests/timvx/test_*.cc) + list(REMOVE_ITEM ALL_TEST_SRCS ${TIMVX_SRCS}) if(NOT ENABLE_VISION) # vision_preprocess and release_task need vision file(GLOB_RECURSE VISION_TEST_SRCS ${PROJECT_SOURCE_DIR}/tests/vision_preprocess/test_*.cc) diff --git a/tests/timvx/CMakeLists.txt b/tests/timvx/CMakeLists.txt new file mode 100755 index 0000000000..9a80aab175 --- /dev/null +++ b/tests/timvx/CMakeLists.txt @@ -0,0 +1,46 @@ +PROJECT(infer_demo C CXX) +CMAKE_MINIMUM_REQUIRED (VERSION 3.10) + +# 指定下载解压后的fastdeploy库路径 +option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.") + +include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake) + +# 添加FastDeploy依赖头文件 +include_directories(${FASTDEPLOY_INCS}) +include_directories(${FastDeploy_INCLUDE_DIRS}) + +set(CMAKE_INSTALL_PREFIX ${CMAKE_SOURCE_DIR}/build/timvx_tests) + +# add test for yolov5 +add_executable(test_yolov5 ${PROJECT_SOURCE_DIR}/test_yolov5.cc) +target_link_libraries(test_yolov5 ${FASTDEPLOY_LIBS}) +install(TARGETS test_yolov5 DESTINATION ./) + +# add test for ppyoloe +add_executable(test_ppyoloe ${PROJECT_SOURCE_DIR}/test_ppyoloe.cc) +target_link_libraries(test_ppyoloe ${FASTDEPLOY_LIBS}) +install(TARGETS test_ppyoloe DESTINATION ./) + +# add test for paddleclas +add_executable(test_clas ${PROJECT_SOURCE_DIR}/test_clas.cc) +target_link_libraries(test_clas ${FASTDEPLOY_LIBS}) +install(TARGETS test_clas DESTINATION ./) + +# add test for pp-liteseg +add_executable(test_ppliteseg ${PROJECT_SOURCE_DIR}/test_ppliteseg.cc) +target_link_libraries(test_ppliteseg ${FASTDEPLOY_LIBS}) +install(TARGETS test_ppliteseg DESTINATION ./) + + +install(DIRECTORY models DESTINATION ./) +install(DIRECTORY images DESTINATION ./) +install(DIRECTORY results DESTINATION ./) + +file(GLOB RUN_TEST run_test.sh) +install(PROGRAMS ${RUN_TEST} DESTINATION ./) + +file(GLOB_RECURSE FASTDEPLOY_LIBS ${FASTDEPLOY_INSTALL_DIR}/lib/lib*.so*) +file(GLOB_RECURSE ALL_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/lib*.so*) +list(APPEND ALL_LIBS ${FASTDEPLOY_LIBS}) +install(PROGRAMS ${ALL_LIBS} DESTINATION lib) diff --git a/tests/timvx/common.h b/tests/timvx/common.h new file mode 100755 index 0000000000..3962e055f2 --- /dev/null +++ b/tests/timvx/common.h @@ -0,0 +1,242 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include +#include "fastdeploy/vision.h" + +std::vector stringSplit(const std::string& str, char delim) { + std::stringstream ss(str); + std::string item; + std::vector elems; + while (std::getline(ss, item, delim)) { + if (!item.empty()) { + elems.push_back(item); + } + } + return elems; +} + + +bool CompareDetResult(const fastdeploy::vision::DetectionResult& res, + const std::string& det_result_file) { + std::ifstream res_str(det_result_file); + if (!res_str.is_open()) { + std::cout<< "Could not open detect result file : " + << det_result_file <<"\n"<< std::endl; + return false; + } + int obj_num = 0; + while (!res_str.eof()) { + std::string line; + std::getline(res_str, line); + if (line.find("DetectionResult") == line.npos + && line.find(",") != line.npos ) { + auto strs = stringSplit(line, ','); + if (strs.size() != 6) { + std::cout<< "Failed to parse result file : " + << det_result_file <<"\n"<< std::endl; + return false; + } + std::vector vals; + for (auto str : strs) { + vals.push_back(atof(str.c_str())); + } + if (abs(res.scores[obj_num] - vals[4]) > 0.3) { + std::cout<< "Score error, the result is: " + << res.scores[obj_num] << " but the expected is: " + << vals[4] << std::endl; + return false; + } + if (abs(res.label_ids[obj_num] - vals[5]) > 0) { + std::cout<< "label error, the result is: " + << res.label_ids[obj_num] << " but the expected is: " + << vals[5] < boxes = res.boxes[obj_num++]; + for (auto i = 0; i < 4; i++) { + if (abs(boxes[i] - vals[i]) > 5) { + std::cout<< "position error, the result is: " + << boxes[i] << " but the expected is: " << vals[i] <(atof(strs[1].c_str())); + if (res.label_ids[obj_num] != label) { + std::cout<< "label error, the result is: " + << res.label_ids[obj_num] << " but the expected is: " + << label<< "\n" << std::endl; + return false; + } + } else if (line.find("scores") != line.npos + && line.find(":") != line.npos) { + auto strs = stringSplit(line, ':'); + if (strs.size() != 2) { + std::cout<< "Failed to parse result file : " + << cls_result_file << "\n" << std::endl; + return false; + } + float score = atof(strs[1].c_str()); + if (abs(res.scores[obj_num] - score) > 1e-1) { + std::cout << "score error, the result is: " + << res.scores[obj_num] << " but the expected is: " + << score << "\n" << std::endl; + return false; + } else { + obj_num++; + } + } else if (line.size()) { + std::cout << "Unknown File. \n" << std::endl; + return false; + } + } + return true; +} + +bool WriteSegResult(const fastdeploy::vision::SegmentationResult& res, + const std::string& seg_result_file) { + std::ofstream res_str(seg_result_file); + if (!res_str.is_open()) { + std::cerr<< "Could not open segmentation result file : " + << seg_result_file <<" to write.\n"<< std::endl; + return false; + } + std::string out; + out = ""; + // save shape + for (auto shape : res.shape) { + out += std::to_string(shape) + ","; + } + out += "\n"; + // save label + for (auto label : res.label_map) { + out += std::to_string(label) + ","; + } + out += "\n"; + // save score + if (res.contain_score_map) { + for (auto score : res.score_map) { + out += std::to_string(score) + ","; + } + } + res_str << out; + return true; +} + +bool CompareSegResult(const fastdeploy::vision::SegmentationResult& res, + const std::string& seg_result_file) { + std::ifstream res_str(seg_result_file); + if (!res_str.is_open()) { + std::cout<< "Could not open detect result file : " + << seg_result_file <<"\n"<< std::endl; + return false; + } + std::string line; + std::getline(res_str, line); + if (line.find(",") == line.npos) { + std::cout << "Unexpected File." << std::endl; + return false; + } + // check shape diff + auto shape_strs = stringSplit(line, ','); + std::vector shape; + for (auto str : shape_strs) { + shape.push_back(static_cast(atof(str.c_str()))); + } + if (shape.size() != res.shape.size()) { + std::cout << "Output shape and expected shape size mismatch, shape size: " + << res.shape.size() << " expected shape size: " + << shape.size() << std::endl; + return false; + } + for (auto i = 0; i < res.shape.size(); i++) { + if (res.shape[i] != shape[i]) { + std::cout << "Output Shape and expected shape mismatch, shape: " + << res.shape[i] << " expected: " << shape[i] << std::endl; + return false; + } + } + std::cout << "Shape check passed!" << std::endl; + + std::getline(res_str, line); + if (line.find(",") == line.npos) { + std::cout << "Unexpected File." << std::endl; + return false; + } + // check label + auto label_strs = stringSplit(line, ','); + std::vector labels; + for (auto str : label_strs) { + labels.push_back(static_cast(atof(str.c_str()))); + } + if (labels.size() != res.label_map.size()) { + std::cout << "Output labels and expected shape size mismatch." << std::endl; + return false; + } + for (auto i = 0; i < res.label_map.size(); i++) { + if (res.label_map[i] != labels[i]) { + std::cout << "Output labels and expected labels mismatch." << std::endl; + return false; + } + } + std::cout << "Label check passed!" << std::endl; + + // check score_map + if (res.contain_score_map) { + auto scores_strs = stringSplit(line, ','); + std::vector scores; + for (auto str : scores_strs) { + scores.push_back(static_cast(atof(str.c_str()))); + } + if (scores.size() != res.score_map.size()) { + std::cout << "Output scores and expected score_map size mismatch." + < 3e-1) { + std::cout << "Output scores and expected scores mismatch." + << std::endl; + return false; + } + } + } + return true; +} diff --git a/tests/timvx/download_models.py b/tests/timvx/download_models.py new file mode 100755 index 0000000000..c49c360cb8 --- /dev/null +++ b/tests/timvx/download_models.py @@ -0,0 +1,165 @@ +import os +import os.path as osp +import logging +import requests +import shutil +import zipfile +import tarfile +import hashlib +import tqdm + +DOWNLOAD_RETRY_LIMIT = 3 + + +def md5check(fullname, md5sum=None): + if md5sum is None: + return True + + logging.info("File {} md5 checking...".format(fullname)) + md5 = hashlib.md5() + with open(fullname, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b""): + md5.update(chunk) + calc_md5sum = md5.hexdigest() + + if calc_md5sum != md5sum: + logging.info("File {} md5 check failed, {}(calc) != " + "{}(base)".format(fullname, calc_md5sum, md5sum)) + return False + return True + + +def move_and_merge_tree(src, dst): + """ + Move src directory to dst, if dst is already exists, + merge src to dst + """ + if not osp.exists(dst): + shutil.move(src, dst) + else: + if not osp.isdir(src): + shutil.move(src, dst) + return + for fp in os.listdir(src): + src_fp = osp.join(src, fp) + dst_fp = osp.join(dst, fp) + if osp.isdir(src_fp): + if osp.isdir(dst_fp): + move_and_merge_tree(src_fp, dst_fp) + else: + shutil.move(src_fp, dst_fp) + elif osp.isfile(src_fp) and \ + not osp.isfile(dst_fp): + shutil.move(src_fp, dst_fp) + + +def download(url, path, rename=None, md5sum=None, show_progress=False): + """ + Download from url, save to path. + url (str): download url + path (str): download to given path + """ + if not osp.exists(path): + os.makedirs(path) + + fname = osp.split(url)[-1] + fullname = osp.join(path, fname) + if rename is not None: + fullname = osp.join(path, rename) + retry_cnt = 0 + while not (osp.exists(fullname) and md5check(fullname, md5sum)): + if retry_cnt < DOWNLOAD_RETRY_LIMIT: + retry_cnt += 1 + else: + logging.debug("{} download failed.".format(fname)) + raise RuntimeError("Download from {} failed. " + "Retry limit reached".format(url)) + + logging.info("Downloading {} from {}".format(fname, url)) + + req = requests.get(url, stream=True) + if req.status_code != 200: + raise RuntimeError("Downloading from {} failed with code " + "{}!".format(url, req.status_code)) + + # For protecting download interupted, download to + # tmp_fullname firstly, move tmp_fullname to fullname + # after download finished + tmp_fullname = fullname + "_tmp" + total_size = req.headers.get('content-length') + with open(tmp_fullname, 'wb') as f: + if total_size and show_progress: + for chunk in tqdm.tqdm( + req.iter_content(chunk_size=1024), + total=(int(total_size) + 1023) // 1024, + unit='KB'): + f.write(chunk) + else: + for chunk in req.iter_content(chunk_size=1024): + if chunk: + f.write(chunk) + shutil.move(tmp_fullname, fullname) + logging.debug("{} download completed.".format(fname)) + + return fullname + + +def decompress(fname): + """ + Decompress for zip and tar file + """ + logging.info("Decompressing {}...".format(fname)) + + # For protecting decompressing interupted, + # decompress to fpath_tmp directory firstly, if decompress + # successed, move decompress files to fpath and delete + # fpath_tmp and remove download compress file. + fpath = osp.split(fname)[0] + fpath_tmp = osp.join(fpath, 'tmp') + if osp.isdir(fpath_tmp): + shutil.rmtree(fpath_tmp) + os.makedirs(fpath_tmp) + + if fname.find('.tar') >= 0 or fname.find('.tgz') >= 0: + with tarfile.open(fname) as tf: + tf.extractall(path=fpath_tmp) + elif fname.find('.zip') >= 0: + with zipfile.ZipFile(fname) as zf: + zf.extractall(path=fpath_tmp) + else: + raise TypeError("Unsupport compress file type {}".format(fname)) + + for f in os.listdir(fpath_tmp): + src_dir = osp.join(fpath_tmp, f) + dst_dir = osp.join(fpath, f) + move_and_merge_tree(src_dir, dst_dir) + + shutil.rmtree(fpath_tmp) + logging.debug("{} decompressed.".format(fname)) + return dst_dir + + +def download_and_decompress(url, path='.', rename=None): + full_name = download(url, path, rename) + if url.count(".tgz") > 0 or url.count(".tar") > 0 or url.count("zip") > 0: + return decompress(full_name) + return + + +def unset_env(key): + del os.environ[key] + + +if __name__ == '__main__': + with open("models_url.txt", "r") as f: + if 'https_proxy' in os.environ or 'http_proxy' in os.environ: + unset_env("https_proxy") + unset_env("http_proxy") + for line in f.readlines(): + url = line.strip() + print("Downloading: ", url) + if line.count(".tgz") > 0 or line.count(".tar") > 0 or line.count( + "zip") > 0: + dst_dir = download_and_decompress(url, "./models") + else: + dst_dir = download(url, "./images", None) diff --git a/tests/timvx/models_url.txt b/tests/timvx/models_url.txt new file mode 100755 index 0000000000..86d0c0ddec --- /dev/null +++ b/tests/timvx/models_url.txt @@ -0,0 +1,8 @@ +https://bj.bcebos.com/paddlehub/fastdeploy/mobilenetv1_ssld_ptq.tar +https://bj.bcebos.com/paddlehub/fastdeploy/resnet50_vd_ptq.tar +https://bj.bcebos.com/fastdeploy/models/yolov5s_ptq_model.tar.gz +https://bj.bcebos.com/fastdeploy/models/ppyoloe_noshare_qat.tar.gz +https://bj.bcebos.com/fastdeploy/models/rk1/ppliteseg.tar.gz +https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg +https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg +https://paddleseg.bj.bcebos.com/dygraph/demo/cityscapes_demo.png diff --git a/tests/timvx/prepare.sh b/tests/timvx/prepare.sh new file mode 100755 index 0000000000..958f314b76 --- /dev/null +++ b/tests/timvx/prepare.sh @@ -0,0 +1,4 @@ +python download_models.py + +wget https://bj.bcebos.com/fastdeploy/models/results.tar.gz +tar -xf results.tar.gz && rm -rf results.tar.gz diff --git a/tests/timvx/requirements.txt b/tests/timvx/requirements.txt new file mode 100755 index 0000000000..78620c472c --- /dev/null +++ b/tests/timvx/requirements.txt @@ -0,0 +1 @@ +tqdm diff --git a/tests/timvx/run_test.sh b/tests/timvx/run_test.sh new file mode 100755 index 0000000000..244cdcf8c2 --- /dev/null +++ b/tests/timvx/run_test.sh @@ -0,0 +1,9 @@ +export LD_LIBRARY_PATH=${PWD}/lib +export VIV_VX_ENABLE_GRAPH_TRANSFORM=-pcq:1 +export VIV_VX_SET_PER_CHANNEL_ENTROPY=100 + +./test_clas models/mobilenetv1_ssld_ptq images/ILSVRC2012_val_00000010.jpeg results/mobilenetv1_cls.txt +./test_clas models/resnet50_vd_ptq/ images/ILSVRC2012_val_00000010.jpeg results/resnet50_cls.txt +./test_yolov5 models/yolov5s_ptq_model/ images/000000014439.jpg results/yolov5_result.txt +./test_ppyoloe models/ppyoloe_noshare_qat/ images/000000014439.jpg results/ppyoloe_result.txt +./test_ppliteseg models/ppliteseg images/cityscapes_demo.png results/ppliteseg_result.txt diff --git a/tests/timvx/test_clas.cc b/tests/timvx/test_clas.cc new file mode 100755 index 0000000000..cd9715ca14 --- /dev/null +++ b/tests/timvx/test_clas.cc @@ -0,0 +1,68 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include + +#include "common.h" +#include "fastdeploy/vision.h" +#ifdef WIN32 +const char sep = '\\'; +#else +const char sep = '/'; +#endif + +void InitAndInfer(const std::string& model_dir, const std::string& image_file, + const std::string& cls_result) { + auto model_file = model_dir + sep + "inference.pdmodel"; + auto params_file = model_dir + sep + "inference.pdiparams"; + auto config_file = model_dir + sep + "inference_cls.yaml"; + fastdeploy::vision::EnableFlyCV(); + fastdeploy::RuntimeOption option; + option.UseTimVX(); + + auto model = fastdeploy::vision::classification::PaddleClasModel( + model_file, params_file, config_file, option); + + assert(model.Initialized()); + + auto im = cv::imread(image_file); + + fastdeploy::vision::ClassifyResult res; + if (!model.Predict(im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + + if (CompareClsResult(res, cls_result)) { + std::cout << model_dir + " run successfully." << std::endl; + } else { + std::cerr << model_dir + " run failed." << std::endl; + } +} + +int main(int argc, char* argv[]) { + if (argc < 4) { + std::cout + << "Usage: test_clas path/to/quant_model " + "path/to/image " + "e.g ./test_clas ./ResNet50_vd_quant ./test.jpeg resnet50_clas.txt" + << std::endl; + return -1; + } + + std::string model_dir = argv[1]; + std::string test_image = argv[2]; + std::string cls_result = argv[3]; + InitAndInfer(model_dir, test_image, cls_result); + return 0; +} diff --git a/tests/timvx/test_ppliteseg.cc b/tests/timvx/test_ppliteseg.cc new file mode 100755 index 0000000000..6858709119 --- /dev/null +++ b/tests/timvx/test_ppliteseg.cc @@ -0,0 +1,76 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common.h" +#include "fastdeploy/vision.h" +#ifdef WIN32 +const char sep = '\\'; +#else +const char sep = '/'; +#endif + +void InitAndInfer(const std::string& model_dir, const std::string& image_file, + const std::string& seg_result_file) { + auto model_file = model_dir + sep + "model.pdmodel"; + auto params_file = model_dir + sep + "model.pdiparams"; + auto config_file = model_dir + sep + "deploy.yaml"; + auto subgraph_file = model_dir + sep + "subgraph.txt"; + fastdeploy::vision::EnableFlyCV(); + fastdeploy::RuntimeOption option; + option.UseTimVX(); + option.SetLiteSubgraphPartitionPath(subgraph_file); + + auto model = fastdeploy::vision::segmentation::PaddleSegModel( + model_file, params_file, config_file, option); + + assert(model.Initialized()); + + auto im = cv::imread(image_file); + + fastdeploy::vision::SegmentationResult res; + if (!model.Predict(im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + + // std::cout << res.Str() << std::endl; + // std::ofstream res_str(seg_result_file); + // if(!WriteSegResult(res, seg_result_file)){ + // std::cerr << "Fail to write to " << seg_result_file< + +#include "common.h" +#include "fastdeploy/vision.h" +#ifdef WIN32 +const char sep = '\\'; +#else +const char sep = '/'; +#endif + +void InferAndCompare(const std::string& model_dir, + const std::string& image_file, + const std::string& det_result) { + auto model_file = model_dir + sep + "model.pdmodel"; + auto params_file = model_dir + sep + "model.pdiparams"; + auto subgraph_file = model_dir + sep + "subgraph.txt"; + fastdeploy::vision::EnableFlyCV(); + fastdeploy::RuntimeOption option; + option.UseTimVX(); + option.SetLiteSubgraphPartitionPath(subgraph_file); + + auto model = fastdeploy::vision::detection::YOLOv5( + model_file, params_file, option, fastdeploy::ModelFormat::PADDLE); + assert(model.Initialized()); + + auto im = cv::imread(image_file); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + + if (CompareDetResult(res, det_result)) { + std::cout << model_dir + " run successfully." << std::endl; + } else { + std::cerr << model_dir + " run failed." << std::endl; + } +} + +int main(int argc, char* argv[]) { + if (argc < 4) { + std::cout << "Usage: infer_demo path/to/quant_model " + "path/to/image " + "run_option, " + "e.g ./infer_demo ./yolov5s_quant ./000000014439.jpg " + "yolov5_result.txt" + << std::endl; + return -1; + } + + std::string model_dir = argv[1]; + std::string test_image = argv[2]; + std::string det_result = argv[3]; + InferAndCompare(model_dir, test_image, det_result); + return 0; +}