diff --git a/f1tenth_gym/envs/rendering/rendering_pyqt.py b/f1tenth_gym/envs/rendering/rendering_pyqt.py index d4010c28..0054066d 100644 --- a/f1tenth_gym/envs/rendering/rendering_pyqt.py +++ b/f1tenth_gym/envs/rendering/rendering_pyqt.py @@ -2,6 +2,7 @@ import logging import math from typing import Any, Callable, Optional +import signal import cv2 import numpy as np @@ -52,7 +53,7 @@ def __init__( render_mode : str rendering mode in ["human", "human_fast", "rgb_array"] render_fps : int - number of frames per second + number of frames per second """ super().__init__() self.params = params @@ -71,8 +72,10 @@ def __init__( self.app = QtWidgets.QApplication([]) self.window = pg.GraphicsLayoutWidget() self.window.setWindowTitle("F1Tenth Gym") - self.window.setGeometry(0, 0, self.render_spec.window_size, self.render_spec.window_size) - self.canvas : pg.PlotItem = self.window.addPlot() + self.window.setGeometry( + 0, 0, self.render_spec.window_size, self.render_spec.window_size + ) + self.canvas: pg.PlotItem = self.window.addPlot() # Disable interactivity self.canvas.setMouseEnabled(x=False, y=False) # Disable mouse panning & zooming @@ -84,33 +87,29 @@ def __init__( legend.mouseDragEvent = lambda *args, **kwargs: None legend.hoverEvent = lambda *args, **kwargs: None # self.scene() is a pyqtgraph.GraphicsScene.GraphicsScene.GraphicsScene - self.window.scene().sigMouseClicked.connect(self.mouse_clicked) + self.window.scene().sigMouseClicked.connect(self.mouse_clicked) self.window.keyPressEvent = self.key_pressed # Remove axes - self.canvas.hideAxis('bottom') - self.canvas.hideAxis('left') + self.canvas.hideAxis("bottom") + self.canvas.hideAxis("left") - # setting plot window background color to yellow - self.window.setBackground('w') + # setting plot window background color to yellow + self.window.setBackground("w") # fps and time renderer self.clock = FrameCounter() - self.fps_renderer = TextObject( - parent=self.canvas, position="bottom_left" - ) - self.time_renderer = TextObject( - parent=self.canvas, position="bottom_right" - ) + self.fps_renderer = TextObject(parent=self.canvas, position="bottom_left") + self.time_renderer = TextObject(parent=self.canvas, position="bottom_right") self.bottom_info_renderer = TextObject( parent=self.canvas, position="bottom_center" ) - self.top_info_renderer = TextObject( - parent=self.canvas, position="top_center" - ) + self.top_info_renderer = TextObject(parent=self.canvas, position="top_center") if self.render_mode in ["human", "human_fast"]: - self.clock.sigFpsUpdate.connect(lambda fps: self.fps_renderer.render(f'FPS: {fps:.1f}')) + self.clock.sigFpsUpdate.connect( + lambda fps: self.fps_renderer.render(f"FPS: {fps:.1f}") + ) colors_rgb = [ [rgb for rgb in ImageColor.getcolor(c, "RGB")] @@ -139,7 +138,7 @@ def __init__( self.image_item = pg.ImageItem(track_map) # Example: Transformed display of ImageItem tr = QtGui.QTransform() # prepare ImageItem transformation: - # Translate image by the origin of the map + # Translate image by the origin of the map tr.translate(self.map_origin[0], self.map_origin[1]) # Scale image by the resolution of the map tr.scale(self.map_resolution, self.map_resolution) @@ -160,6 +159,8 @@ def __init__( self.follow_agent_flag: bool = False self.agent_to_follow: int = None + # Allow a KeyboardInterrupt to kill the window without segfaulting + signal.signal(signal.SIGINT, signal.SIG_DFL) self.window.show() def update(self, state: dict) -> None: @@ -202,7 +203,7 @@ def add_renderer_callback(self, callback_fn: Callable[[EnvRenderer], None]) -> N callback function to be called at every rendering step """ self.callbacks.append(callback_fn) - + def key_pressed(self, event: QtGui.QKeyEvent) -> None: """ Handle key press events. @@ -233,9 +234,7 @@ def mouse_clicked(self, event: QtGui.QMouseEvent) -> None: if self.agent_to_follow is None: self.agent_to_follow = 0 else: - self.agent_to_follow = (self.agent_to_follow + 1) % len( - self.agent_ids - ) + self.agent_to_follow = (self.agent_to_follow + 1) % len(self.agent_ids) self.active_map_renderer = "car" elif event.button() == QtCore.Qt.MouseButton.RightButton: @@ -245,9 +244,7 @@ def mouse_clicked(self, event: QtGui.QMouseEvent) -> None: if self.agent_to_follow is None: self.agent_to_follow = 0 else: - self.agent_to_follow = (self.agent_to_follow - 1) % len( - self.agent_ids - ) + self.agent_to_follow = (self.agent_to_follow - 1) % len(self.agent_ids) self.active_map_renderer = "car" elif event.button() == QtCore.Qt.MouseButton.MiddleButton: @@ -257,7 +254,7 @@ def mouse_clicked(self, event: QtGui.QMouseEvent) -> None: self.agent_to_follow = None self.active_map_renderer = "map" - + def render(self) -> Optional[np.ndarray]: """ Render the current state in a frame. @@ -269,7 +266,6 @@ def render(self) -> Optional[np.ndarray]: if render_mode is "rgb_array", returns the rendered frame as an array """ if self.draw_flag: - # draw cars for i in range(len(self.agent_ids)): self.cars[i].render() @@ -284,15 +280,13 @@ def render(self) -> Optional[np.ndarray]: self.canvas.setYRange(ego_y - 10, ego_y + 10) else: self.canvas.autoRange() - + agent_to_follow_id = ( self.agent_ids[self.agent_to_follow] if self.agent_to_follow is not None else None ) - self.bottom_info_renderer.render( - text=f"Focus on: {agent_to_follow_id}" - ) + self.bottom_info_renderer.render(text=f"Focus on: {agent_to_follow_id}") if self.render_spec.show_info: self.top_info_renderer.render(text=INSTRUCTION_TEXT) @@ -303,7 +297,7 @@ def render(self) -> Optional[np.ndarray]: if self.render_mode in ["human", "human_fast"]: assert self.window is not None - else: + else: # rgb_array # TODO: extract the frame from the canvas frame = None @@ -328,7 +322,13 @@ def render_points( size of the points in pixels, by default 1 """ return self.canvas.plot( - points[:, 0], points[:, 1], pen=None, symbol="o", symbolPen=pg.mkPen(color=color, width=0), symbolBrush=pg.mkBrush(color=color, width=0), symbolSize=size + points[:, 0], + points[:, 1], + pen=None, + symbol="o", + symbolPen=pg.mkPen(color=color, width=0), + symbolBrush=pg.mkBrush(color=color, width=0), + symbolSize=size, ) def render_lines( @@ -374,7 +374,7 @@ def render_closed_lines( """ # Append the first point to the end to close the loop points = np.vstack([points, points[0]]) - + pen = pg.mkPen(color=pg.mkColor(*color), width=size) pen.setCapStyle(pg.QtCore.Qt.PenCapStyle.RoundCap) pen.setJoinStyle(pg.QtCore.Qt.PenJoinStyle.RoundJoin) @@ -383,7 +383,6 @@ def render_closed_lines( points[:, 0], points[:, 1], pen=pen, cosmetic=True, antialias=True ) ## setting pen=None disables line drawing - def close(self) -> None: """ Close the rendering environment.