Skip to content

Commit f59deb6

Browse files
authored
Panoptic segmentation: download pretrained weights in ROS node (opendr-eu#269)
* Download pre-trained model weights * Add backward compatibility * Do not remove downloaded checkpoint file * Resolve PEP8 issue * Consistent with ROS2 * Address review * Reduce queue size to 1 * Set default for input rgb image topic
1 parent f71fb87 commit f59deb6

File tree

4 files changed

+82
-43
lines changed

4 files changed

+82
-43
lines changed

projects/opendr_ws/src/perception/README.md

+7-6
Original file line numberDiff line numberDiff line change
@@ -164,15 +164,16 @@ rosrun perception object_detection_2d_gem.py
164164
A ROS node for performing panoptic segmentation on a specified RGB image stream using the [EfficientPS](../../../../src/opendr/perception/panoptic_segmentation/README.md) network.
165165
Assuming that the OpenDR catkin workspace has been sourced, the node can be started with:
166166
```shell
167-
rosrun perception panoptic_segmentation_efficient_ps.py CHECKPOINT IMAGE_TOPIC
167+
rosrun perception panoptic_segmentation_efficient_ps.py
168168
```
169-
with `CHECKPOINT` pointing to the path to the trained model weights and `IMAGE_TOPIC` specifying the ROS topic, to which the node will subscribe.
170169

171-
Additionally, the following optional arguments are available:
170+
The following optional arguments are available:
172171
- `-h, --help`: show a help message and exit
173-
- `--heamap_topic HEATMAP_TOPIC`: publish the semantic and instance maps on `HEATMAP_TOPIC`
174-
- `--visualization_topic VISUALIZATION_TOPIC`: publish the panoptic segmentation map as an RGB image on `VISUALIZATION_TOPIC` or a more detailed overview if using the `--detailed_visualization` flag
175-
- `--detailed_visualization`: generate a combined overview of the input RGB image and the semantic, instance, and panoptic segmentation maps
172+
- `--input_rgb_image_topic INPUT_RGB_IMAGE_TOPIC` : listen to RGB images on this topic (default=`/usb_cam/image_raw`)
173+
- `--checkpoint CHECKPOINT` : download pretrained models [cityscapes, kitti] or load from the provided path (default=`cityscapes`)
174+
- `--output_rgb_image_topic OUTPUT_RGB_IMAGE_TOPIC`: publish the semantic and instance maps on this topic as `OUTPUT_HEATMAP_TOPIC/semantic` and `OUTPUT_HEATMAP_TOPIC/instance` (default=`/opendir/panoptic`)
175+
- `--visualization_topic VISUALIZATION_TOPIC`: publish the panoptic segmentation map as an RGB image on `VISUALIZATION_TOPIC` or a more detailed overview if using the `--detailed_visualization` flag (default=`/opendr/panoptic/rgb_visualization`)
176+
- `--detailed_visualization`: generate a combined overview of the input RGB image and the semantic, instance, and panoptic segmentation maps and publish it on `OUTPUT_RGB_IMAGE_TOPIC` (default=deactivated)
176177

177178

178179
## Semantic Segmentation ROS Node

projects/opendr_ws/src/perception/scripts/panoptic_segmentation_efficient_ps.py

+55-30
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import sys
17+
from pathlib import Path
1618
import argparse
1719
from typing import Optional
1820

@@ -29,27 +31,31 @@
2931

3032
class EfficientPsNode:
3133
def __init__(self,
34+
input_rgb_image_topic: str,
3235
checkpoint: str,
33-
input_image_topic: str,
3436
output_heatmap_topic: Optional[str] = None,
35-
output_visualization_topic: Optional[str] = None,
37+
output_rgb_visualization_topic: Optional[str] = None,
3638
detailed_visualization: bool = False
3739
):
3840
"""
3941
Initialize the EfficientPS ROS node and create an instance of the respective learner class.
40-
:param checkpoint: Path to a saved model
42+
:param checkpoint: This is either a path to a saved model or one of [cityscapes, kitti] to download
43+
pre-trained model weights.
4144
:type checkpoint: str
42-
:param input_image_topic: ROS topic for the input image stream
43-
:type input_image_topic: str
45+
:param input_rgb_image_topic: ROS topic for the input image stream
46+
:type input_rgb_image_topic: str
4447
:param output_heatmap_topic: ROS topic for the predicted semantic and instance maps
4548
:type output_heatmap_topic: str
46-
:param output_visualization_topic: ROS topic for the generated visualization of the panoptic map
47-
:type output_visualization_topic: str
49+
:param output_rgb_visualization_topic: ROS topic for the generated visualization of the panoptic map
50+
:type output_rgb_visualization_topic: str
51+
:param detailed_visualization: if True, generate a combined overview of the input RGB image and the
52+
semantic, instance, and panoptic segmentation maps and publish it on output_rgb_visualization_topic
53+
:type detailed_visualization: bool
4854
"""
55+
self.input_rgb_image_topic = input_rgb_image_topic
4956
self.checkpoint = checkpoint
50-
self.input_image_topic = input_image_topic
5157
self.output_heatmap_topic = output_heatmap_topic
52-
self.output_visualization_topic = output_visualization_topic
58+
self.output_rgb_visualization_topic = output_rgb_visualization_topic
5359
self.detailed_visualization = detailed_visualization
5460

5561
# Initialize all ROS related things
@@ -59,14 +65,27 @@ def __init__(self,
5965
self._visualization_publisher = None
6066

6167
# Initialize the panoptic segmentation network
62-
self._learner = EfficientPsLearner()
68+
config_file = Path(sys.modules[
69+
EfficientPsLearner.__module__].__file__).parent / 'configs' / 'singlegpu_cityscapes.py'
70+
self._learner = EfficientPsLearner(str(config_file))
71+
72+
# Other
73+
self._tmp_folder = Path(__file__).parent.parent / 'tmp' / 'efficientps'
74+
self._tmp_folder.mkdir(exist_ok=True, parents=True)
6375

6476
def _init_learner(self) -> bool:
6577
"""
66-
Load the weights from the specified checkpoint file.
78+
The model can be initialized via
79+
1. downloading pre-trained weights for Cityscapes or KITTI.
80+
2. passing a path to an existing checkpoint file.
6781
6882
This has not been done in the __init__() function since logging is available only once the node is registered.
6983
"""
84+
if self.checkpoint in ['cityscapes', 'kitti']:
85+
file_path = EfficientPsLearner.download(str(self._tmp_folder),
86+
trained_on=self.checkpoint)
87+
self.checkpoint = file_path
88+
7089
if self._learner.load(self.checkpoint):
7190
rospy.loginfo('Successfully loaded the checkpoint.')
7291
return True
@@ -78,19 +97,20 @@ def _init_subscribers(self):
7897
"""
7998
Subscribe to all relevant topics.
8099
"""
81-
rospy.Subscriber(self.input_image_topic, ROS_Image, self.callback)
100+
rospy.Subscriber(self.input_rgb_image_topic, ROS_Image, self.callback, queue_size=1, buff_size=10000000)
82101

83102
def _init_publisher(self):
84103
"""
85104
Set up the publishers as requested by the user.
86105
"""
87106
if self.output_heatmap_topic is not None:
88-
self._instance_heatmap_publisher = rospy.Publisher(f'{self.output_heatmap_topic}/instance', ROS_Image,
89-
queue_size=10)
90-
self._semantic_heatmap_publisher = rospy.Publisher(f'{self.output_heatmap_topic}/semantic', ROS_Image,
91-
queue_size=10)
92-
if self.output_visualization_topic is not None:
93-
self._visualization_publisher = rospy.Publisher(self.output_visualization_topic, ROS_Image, queue_size=10)
107+
self._instance_heatmap_publisher = rospy.Publisher(
108+
f'{self.output_heatmap_topic}/instance', ROS_Image, queue_size=10)
109+
self._semantic_heatmap_publisher = rospy.Publisher(
110+
f'{self.output_heatmap_topic}/semantic', ROS_Image, queue_size=10)
111+
if self.output_rgb_visualization_topic is not None:
112+
self._visualization_publisher = rospy.Publisher(self.output_rgb_visualization_topic,
113+
ROS_Image, queue_size=10)
94114

95115
def listen(self):
96116
"""
@@ -128,26 +148,31 @@ def callback(self, data: ROS_Image):
128148
if self._semantic_heatmap_publisher is not None and self._semantic_heatmap_publisher.get_num_connections() > 0:
129149
self._semantic_heatmap_publisher.publish(self._bridge.to_ros_image(prediction[1]))
130150

131-
except Exception:
132-
rospy.logwarn('Failed to generate prediction.')
151+
except Exception as e:
152+
rospy.logwarn(f'Failed to generate prediction: {e}')
133153

134154

135155
if __name__ == '__main__':
136-
parser = argparse.ArgumentParser()
137-
parser.add_argument('checkpoint', type=str, help='load the model weights from the provided path')
138-
parser.add_argument('image_topic', type=str, help='listen to images on this topic')
139-
parser.add_argument('--heatmap_topic', type=str, help='publish the semantic and instance maps on this topic')
140-
parser.add_argument('--visualization_topic', type=str,
156+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
157+
parser.add_argument('input_rgb_image_topic', type=str, default='/usb_cam/image_raw',
158+
help='listen to RGB images on this topic')
159+
parser.add_argument('--checkpoint', type=str, default='cityscapes',
160+
help='download pretrained models [cityscapes, kitti] or load from the provided path')
161+
parser.add_argument('--output_heatmap_topic', type=str, default='/opendr/panoptic',
162+
help='publish the semantic and instance maps on this topic as "OUTPUT_HEATMAP_TOPIC/semantic" \
163+
and "OUTPUT_HEATMAP_TOPIC/instance"')
164+
parser.add_argument('--output_rgb_image_topic', type=str,
165+
default='/opendr/panoptic/rgb_visualization',
141166
help='publish the panoptic segmentation map as an RGB image on this topic or a more detailed \
142167
overview if using the --detailed_visualization flag')
143168
parser.add_argument('--detailed_visualization', action='store_true',
144169
help='generate a combined overview of the input RGB image and the semantic, instance, and \
145-
panoptic segmentation maps')
170+
panoptic segmentation maps and publish it on OUTPUT_RGB_IMAGE_TOPIC')
146171
args = parser.parse_args()
147172

148-
efficient_ps_node = EfficientPsNode(args.checkpoint,
149-
args.image_topic,
150-
args.heatmap_topic,
151-
args.visualization_topic,
173+
efficient_ps_node = EfficientPsNode(args.input_rgb_image_topic,
174+
args.checkpoint,
175+
args.output_heatmap_topic,
176+
args.output_rgb_image_topic,
152177
args.detailed_visualization)
153178
efficient_ps_node.listen()

src/opendr/engine/target.py

+8
Original file line numberDiff line numberDiff line change
@@ -1072,6 +1072,14 @@ def numpy(self):
10721072
# Since this class stores the data as NumPy arrays, we can directly return the data.
10731073
return self.data
10741074

1075+
def opencv(self):
1076+
"""
1077+
Required to support the ros bridge for images.
1078+
:return: a NumPy-compatible representation of data
1079+
:rtype: numpy.ndarray
1080+
"""
1081+
return self.numpy()
1082+
10751083
def shape(self) -> Tuple[int, ...]:
10761084
"""
10771085
Returns the shape of the underlying NumPy array.

src/opendr/perception/panoptic_segmentation/efficient_ps/efficient_ps_learner.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -306,17 +306,18 @@ def infer(self,
306306
warnings.warn('The current model has not been trained.')
307307
self.model.eval()
308308

309-
# Build the data pipeline
310-
test_pipeline = Compose(self._cfg.test_pipeline[1:])
311-
device = next(self.model.parameters()).device
312-
313-
# Convert to the format expected by the mmdetection API
314309
single_image_mode = False
315310
if isinstance(batch, Image):
316311
batch = [batch]
317312
single_image_mode = True
313+
314+
# Convert to the format expected by the mmdetection API
318315
mmdet_batch = []
316+
device = next(self.model.parameters()).device
319317
for img in batch:
318+
# Change the processing size according to the input image
319+
self._cfg.test_pipeline[1:][0]['img_scale'] = batch[0].data.shape[1:]
320+
test_pipeline = Compose(self._cfg.test_pipeline[1:])
320321
# Convert from OpenDR convention (CHW/RGB) to the expected format (HWC/BGR)
321322
img_ = img.convert('channels_last', 'bgr')
322323
mmdet_img = {'filename': None, 'img': img_, 'img_shape': img_.shape, 'ori_shape': img_.shape}
@@ -481,8 +482,12 @@ def update_to(b=1, bsize=1, total=None):
481482

482483
return update_to
483484

484-
with tqdm(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, desc=f'Downloading {filename}') as pbar:
485-
urllib.request.urlretrieve(url, filename, pbar_hook(pbar))
485+
if os.path.exists(filename) and os.path.isfile(filename):
486+
print(f'File already downloaded: {filename}')
487+
else:
488+
with tqdm(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, desc=f'Downloading {filename}') \
489+
as pbar:
490+
urllib.request.urlretrieve(url, filename, pbar_hook(pbar))
486491
return filename
487492

488493
@staticmethod

0 commit comments

Comments
 (0)