Skip to content

Commit

Permalink
test.py add cpu forward
Browse files Browse the repository at this point in the history
  • Loading branch information
duoduo committed Jul 2, 2022
1 parent e64a02c commit df3630d
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
from utils.tool import *
from module.detector import Detector

# 指定后端设备CUDA&CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if __name__ == '__main__':
# 指定训练配置文件
parser = argparse.ArgumentParser()
Expand All @@ -18,12 +15,25 @@
parser.add_argument('--img', type=str, default='', help='The path of test image')
parser.add_argument('--thresh', type=float, default=0.8, help='The path of test image')
parser.add_argument('--onnx', action="store_true", default=False, help='Export onnx file')
parser.add_argument('--cpu', action="store_true", default=False, help='Run on cpu')

opt = parser.parse_args()
assert os.path.exists(opt.yaml), "请指定正确的配置文件路径"
assert os.path.exists(opt.weight), "请指定正确的模型路径"
assert os.path.exists(opt.img), "请指定正确的测试图像路径"

# 选择推理后端
if opt.cpu:
print("run on cpu...")
device = torch.device("cpu")
else:
if torch.cuda.is_available():
print("run on gpu...")
device = torch.device("cuda")
else:
print("run on cpu...")
device = torch.device("cpu")

# 解析yaml配置文件
cfg = LoadYaml(opt.yaml)
print(cfg)
Expand Down Expand Up @@ -85,4 +95,4 @@
cv2.putText(ori_img, '%.2f' % obj_score, (x1, y1 - 5), 0, 0.7, (0, 255, 0), 2)
cv2.putText(ori_img, category, (x1, y1 - 25), 0, 0.7, (0, 255, 0), 2)

cv2.imwrite("result.png", ori_img)
cv2.imwrite("result.png", ori_img)

0 comments on commit df3630d

Please sign in to comment.