-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun_perception_module.py
executable file
·56 lines (45 loc) · 1.93 KB
/
run_perception_module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#!/usr/bin/python
import rospy
import os
import argparse
from visualization_msgs.msg import Marker
from pose_estimators.run_perception_module import run_detection
from pose_estimators.perception_module import PerceptionModule
from pose_estimators.marker_manager import MarkerManager
from food_detector.spnet_detector import SPNetDetector
from food_detector.retinanet_detector import RetinaNetDetector
from food_detector.spanet_detector import SPANetDetector
import food_detector.ada_feeding_demo_config as conf
def _build_pose_estimator(demo_type, node_name):
    """Construct the food detector selected by ``--demo-type``.

    Args:
        demo_type: One of ``'spnet'``, ``'spanet'`` or ``'retinanet'``.
        node_name: Name of the running ROS node, forwarded to the detectors
            whose constructors accept it.

    Returns:
        The detector instance to use as the perception module's pose estimator.

    Raises:
        ValueError: If ``demo_type`` is not a known detector name.
    """
    if demo_type == 'retinanet':
        # Honor the shared config flag like the other detectors do (this was
        # previously hard-coded to True, ignoring conf.use_cuda).
        return RetinaNetDetector(use_cuda=conf.use_cuda, node_name=node_name)
    if demo_type == "spnet":
        return SPNetDetector(use_cuda=conf.use_cuda, node_name=node_name)
    if demo_type == "spanet":
        # SPANetDetector is constructed without node_name, as in the original
        # script.
        return SPANetDetector(use_cuda=conf.use_cuda)
    # Unreachable through the CLI (argparse enforces `choices`), kept as a
    # guard for direct callers.
    raise ValueError("Unknown demo type")


def main():
    """Parse CLI args, build the selected detector and run the detection loop."""
    parser = argparse.ArgumentParser(
        description="Run perception module for ada feeding projects")
    parser.add_argument(
        "--demo-type", choices=['spnet', 'spanet', 'retinanet'],
        required=True)
    # rospy.myargv() strips ROS remapping arguments before argparse sees them.
    args = parser.parse_args(rospy.myargv()[1:])

    rospy.init_node('food_detector')

    # GPU selection must be in place before any detector is constructed:
    # the visible-device list is read when CUDA is initialized, so setting it
    # afterwards has no effect on an already-loaded model.
    if conf.use_cuda:
        os.environ['CUDA_VISIBLE_DEVICES'] = conf.gpus
        print("Using CUDA!")

    pose_estimator = _build_pose_estimator(args.demo_type, rospy.get_name())

    marker_manager = MarkerManager(
        marker_type=Marker.CUBE,
        scale=[0.05, 0.01, 0.01],
        color=[0.5, 1.0, 0.5, 0.1],
        count_items=False)  # spnet and spanet handle item counting internally

    perception_module = PerceptionModule(
        pose_estimator=pose_estimator,
        marker_manager=marker_manager,
        detection_frame_marker_topic=None,
        detection_frame=conf.camera_tf,
        destination_frame=conf.destination_frame,
        purge_all_markers_per_update=True)

    run_detection(rospy.get_name(), conf.frequency, perception_module)


if __name__ == '__main__':
    main()