Skip to content

Commit

Permalink
starting jax dev
Browse files Browse the repository at this point in the history
  • Loading branch information
hzheng40 committed Jan 9, 2024
1 parent 0f3b630 commit 7bc7ca2
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 0 deletions.
32 changes: 32 additions & 0 deletions gym/f110_gym/envs/laser_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@
from numba import njit
from scipy.ndimage import distance_transform_edt as edt

from f110_gym.envs.utils import edt as jedt

import jax
import jax.numpy as jnp


def get_dt(bitmap, resolution):
"""
Expand All @@ -50,6 +55,33 @@ def get_dt(bitmap, resolution):
return dt


@jax.jit
def get_dt_jax(bitmap: jax.Array, resolution: float) -> jax.Array:
"""
Exact Eucliden Distance transformation, returns the distance matrix from the input bitmap.
Follows scipy.ndimage.distance_transform_edt
n
y_i = sqrt(sum (x[i]-b[i])**2)
i
where b[i] is the background point (value 0) with the smallest Euclidean distance to input points x[i]
n is number of dimensions, default n=2 here.
Args:
bitmap (jax.Array, (n, m)): input binary bitmap of the environment, where 0 is obstacles, and 255 (or anything > 0) is freespace
resolution (float): resolution of the input bitmap (m/cell)
Returns:
dt (jax.Array, (n, m)): output distance matrix, where each cell has the corresponding distance (in meters) to the closest obstacle
"""
# convert to binary
input = jnp.atleast_1d(jnp.where(bitmap, 1, 0).astype(jnp.int8))
# features
pass



@njit(cache=True)
def xy_2_rc(x, y, orig_x, orig_y, orig_c, orig_s, height, width, resolution):
"""
Expand Down
76 changes: 76 additions & 0 deletions gym/f110_gym/envs/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial

# types
from typing import Any, Dict, TypeVar

Expand All @@ -23,3 +28,74 @@ def deep_update(
else:
updated_mapping[k] = v
return updated_mapping


def distance_transform_1d(vector):
sorted_indices = np.argsort(np.abs(vector))
closest_index = sorted_indices[0]
closest_value = vector[closest_index]

return closest_value


@jax.jit
def edt(image):
"""
8 neighborhood exact euclidean distance transform
"""

@jax.jit
def replace(v, i, val, inc):
v = v.at[i].set(val)
inc += 2
return v, inc

@jax.jit
def dont_replace(v, i, val, inc):
inc = 1
return v, inc



h, w = image.shape
max_d = h**2 + w**2
dt = jnp.where(image, max_d, 0)

# vertical pass
def dt_1d(vec):
def pass_fn(vec):
return vec
def calc_dist(vec):

return dist

dt = jax.lax.cond(
jnp.count_nonzero(vec) == len(vec),
pass_fn,
calc_dist,
vec,
)
return dt

dt = jax.vmap(dt_1d, in_axes=[1, ])(dt)

# horizontal pass
for i in range(h):
# copy row of vertical distances
dtc = dt[i, :]
# column positions
for j in range(w):
# init min dist
dist_min = dtc[j]
# comparee with column position
for k in range(w):
# combine vertical and horizontal components
d = dtc[k] + (k - j) ** 2

dist_min = jnp.select([dist_min > d, dist_min <= d], [d, dist_min])
# if dist_min > d:
# # new minimum
# dist_min = d
dt.at[i, j].set(dist_min)

return jnp.sqrt(dt)
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"requests",
"shapely",
"opencv-python",
"jax[cpu]",
],
extras_require={
"dev": [
Expand Down
28 changes: 28 additions & 0 deletions test_jedt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from f110_gym.envs.utils import edt as jedt
from scipy.ndimage import distance_transform_edt as edt
from f110_gym.envs.track import Track
import numpy as np
import jax.numpy as jnp

track = Track.from_track_name("Spielberg")
map_img_real = track.occupancy_map

map_img = 255 * np.ones((5, 5))
map_img[2, 3] = 0

dt = edt(map_img)
jdt = jedt(jnp.array(map_img))

print(map_img)
print('------')
print(dt)
print('------')
print(jdt)
assert np.allclose(dt, jdt)
print(f"Transforms equal: {np.allclose(dt, jdt)}.")

# benchmark
import timeit

print(timeit.timeit("edt(map_img)", globals=globals()))
print(timeit.timeit("jedt(jnp.array(map_img))", globals=globals()))

0 comments on commit 7bc7ca2

Please sign in to comment.