diff --git a/examples/create_edt.py b/examples/create_edt.py new file mode 100644 index 00000000..ca951a2b --- /dev/null +++ b/examples/create_edt.py @@ -0,0 +1,21 @@ +from f110_gym.envs.track import Track +from scipy.ndimage import distance_transform_edt as edt +import argparse +import numpy as np + +parser = argparse.ArgumentParser() +parser.add_argument("--track_name", type=str, required=True) +args = parser.parse_args() + +print("Loading a map without edt, a warning should appear") +track = Track.from_track_name(args.track_name) +occupancy_map = track.occupancy_map +resolution = track.spec.resolution + +dt = resolution * edt(occupancy_map) + +# saving +np.save(track.filepath, dt) + +print("Loading a map with edt, warning should no longer appear") +track_wedt = Track.from_track_name(args.track_name) \ No newline at end of file diff --git a/gym/f110_gym/envs/track.py b/gym/f110_gym/envs/track.py index 541994bf..ec3edc5c 100644 --- a/gym/f110_gym/envs/track.py +++ b/gym/f110_gym/envs/track.py @@ -7,6 +7,7 @@ import numpy as np import requests import yaml +import warnings from f110_gym.envs.cubic_spline import CubicSpline2D from PIL import Image from PIL.Image import Transpose @@ -174,6 +175,7 @@ def __init__( filepath: str, ext: str, occupancy_map: np.ndarray, + edt: np.ndarray, centerline: Raceline = None, raceline: Raceline = None, ): @@ -214,6 +216,15 @@ def from_track_name(track: str): occupancy_map[occupancy_map <= 128] = 0.0 occupancy_map[occupancy_map > 128] = 255.0 + # if exists, load edt + if (track_dir / f"{track}_map.npy").exists(): + edt = np.load(track_dir / f"{track}_map.npy") + else: + edt = None + warnings.warn( + f"Track Distance Transform file at {track_dir / f'{track}_map.npy'} not found, will be created before initialization." + ) + # if exists, load centerline if (track_dir / f"{track}_centerline.csv").exists(): centerline = Raceline.from_centerline_file( @@ -221,6 +232,9 @@ def from_track_name(track: str): ) else: centerline = None + warnings.warn( + f"Track Centerline file at {track_dir / f'{track}_centerline.csv'} not found, setting None." + ) # if exists, load raceline if (track_dir / f"{track}_raceline.csv").exists(): @@ -229,12 +243,21 @@ def from_track_name(track: str): ) else: raceline = centerline + if centerline is None: + warnings.warn( + f"Track Raceline file at {track_dir / f'{track}_raceline.csv'} not found, setting None." + ) + else: + warnings.warn( + f"Track Raceline file at {track_dir / f'{track}_raceline.csv'} not found, using Centerline." + ) return Track( spec=track_spec, filepath=str((track_dir / map_filename.stem).absolute()), ext=map_filename.suffix, occupancy_map=occupancy_map, + edt=edt, centerline=centerline, raceline=raceline, )