Skip to content

Commit

Permalink
add edt as part of track class, add script for creating edt
Browse files Browse the repository at this point in the history
  • Loading branch information
hzheng40 committed Jan 18, 2024
1 parent 0f3b630 commit 38697c4
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
21 changes: 21 additions & 0 deletions examples/create_edt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from f110_gym.envs.track import Track
from scipy.ndimage import distance_transform_edt as edt
import argparse
import numpy as np

parser = argparse.ArgumentParser()
parser.add_argument("--track_name", type=str, required=True)
args = parser.parse_args()

print("Loading a map without edt, a warning should appear")
track = Track.from_track_name(args.track_name)
occupancy_map = track.occupancy_map
resolution = track.spec.resolution

dt = resolution * edt(occupancy_map)

# saving
np.save(track.filepath, dt)

print("Loading a map with edt, warning should no longer appear")
track_wedt = Track.from_track_name(args.track_name)
23 changes: 23 additions & 0 deletions gym/f110_gym/envs/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import requests
import yaml
import warnings
from f110_gym.envs.cubic_spline import CubicSpline2D
from PIL import Image
from PIL.Image import Transpose
Expand Down Expand Up @@ -174,6 +175,7 @@ def __init__(
filepath: str,
ext: str,
occupancy_map: np.ndarray,
edt: np.ndarray,
centerline: Raceline = None,
raceline: Raceline = None,
):
Expand Down Expand Up @@ -214,13 +216,25 @@ def from_track_name(track: str):
occupancy_map[occupancy_map <= 128] = 0.0
occupancy_map[occupancy_map > 128] = 255.0

# if exists, load edt
if (track_dir / f"{track}_map.npy").exists():
edt = np.load(track_dir / f"{track}_map.npy")
else:
edt = None
warnings.warn(
f"Track Distance Transform file at {track_dir / f'{track}_map.npy'} not found, will be created before initialization."
)

# if exists, load centerline
if (track_dir / f"{track}_centerline.csv").exists():
centerline = Raceline.from_centerline_file(
track_dir / f"{track}_centerline.csv"
)
else:
centerline = None
warnings.warn(
f"Track Centerline file at {track_dir / f'{track}_centerline.csv'} not found, setting None."
)

# if exists, load raceline
if (track_dir / f"{track}_raceline.csv").exists():
Expand All @@ -229,12 +243,21 @@ def from_track_name(track: str):
)
else:
raceline = centerline
if centerline is None:
warnings.warn(
f"Track Raceline file at {track_dir / f'{track}_raceline.csv'} not found, setting None."
)
else:
warnings.warn(
f"Track Raceline file at {track_dir / f'{track}_raceline.csv'} not found, using Centerline."
)

return Track(
spec=track_spec,
filepath=str((track_dir / map_filename.stem).absolute()),
ext=map_filename.suffix,
occupancy_map=occupancy_map,
edt=edt,
centerline=centerline,
raceline=raceline,
)
Expand Down

0 comments on commit 38697c4

Please sign in to comment.