Skip to content

Commit

Permalink
Merge pull request #44 from epignatelli/fix/overflow
Browse files Browse the repository at this point in the history
Fix numerical overflow of jax array from `sys.maxsize`
  • Loading branch information
epignatelli authored Jul 19, 2023
2 parents 9b832b6 + 8c11f73 commit b838db7
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 13 deletions.
2 changes: 1 addition & 1 deletion navix/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
# under the License.


__version__ = "0.3.8"
__version__ = "0.3.9"
__version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit())
5 changes: 2 additions & 3 deletions navix/environments/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from __future__ import annotations

import sys
import abc
from enum import IntEnum
from typing import Any, Callable, Dict
Expand Down Expand Up @@ -84,10 +83,10 @@ def observation_space(self) -> Space:
if self.observation_fn == observations.none:
return Continuous(shape=())
elif self.observation_fn == observations.categorical:
return Discrete(sys.maxsize, shape=(self.height, self.width))
return Discrete(shape=(self.height, self.width))
elif self.observation_fn == observations.categorical_first_person:
radius = observations.RADIUS
return Discrete(sys.maxsize, shape=(radius + 1, radius * 2 + 1))
return Discrete(shape=(radius + 1, radius * 2 + 1))
elif self.observation_fn == observations.rgb:
return Discrete(
256,
Expand Down
21 changes: 12 additions & 9 deletions navix/spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@


from __future__ import annotations
from typing import Callable

import jax
import jax.numpy as jnp
Expand All @@ -25,13 +24,15 @@
from jax.core import ShapedArray, Shape


POS_INF = jnp.asarray(1e16)
NEG_INF = jnp.asarray(-1e16)
MIN_INT = jax.numpy.iinfo(jnp.int16).min
MAX_INT = jax.numpy.iinfo(jnp.int16).max
MIN_INT_ARR = jnp.asarray(MIN_INT)
MAX_INT_ARR = jnp.asarray(MAX_INT)


class Space(ShapedArray):
minimum: Array = NEG_INF
maximum: Array = POS_INF
minimum: Array = MIN_INT_ARR
maximum: Array = MAX_INT_ARR

def __repr__(self):
return "{}, min={}, max={})".format(
Expand All @@ -43,19 +44,21 @@ def sample(self, key: KeyArray) -> Array:


class Discrete(Space):
def __init__(self, n_elements: int, shape: Shape = (), dtype=jnp.int32):
def __init__(self, n_elements: int = MAX_INT, shape: Shape = (), dtype=jnp.int32):
super().__init__(shape, dtype)
self.minimum = jnp.asarray(0)
self.maximum = jnp.asarray(n_elements - 1)

def sample(self, key: KeyArray) -> Array:
return jax.random.randint(
key, self.shape, self.minimum, self.maximum, self.dtype
item = jax.random.randint(
key, self.shape, self.minimum, self.maximum
)
# randint cannot draw jnp.uint, so we cast it later
return jnp.asarray(item, dtype=self.dtype)


class Continuous(Space):
def __init__(self, shape: Shape = (), minimum=NEG_INF, maximum=POS_INF):
def __init__(self, shape: Shape = (), minimum=MIN_INT_ARR, maximum=MAX_INT_ARR):
super().__init__(shape, jnp.float32)
self.minimum = minimum
self.maximum = maximum
Expand Down
34 changes: 34 additions & 0 deletions tests/test_spaces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import jax
import jax.numpy as jnp
from navix.spaces import Space, Continuous, Discrete, MAX_INT, MIN_INT


def test_discrete():
key = jax.random.PRNGKey(42)
elements = (5, 0, MAX_INT, MIN_INT)
shapes = ((), (0,), (0, 0), (1, 2), (5, 5))
dtypes = (jnp.int8, jnp.int16, jnp.int32)
for element in elements:
for shape in shapes:
for dtype in dtypes:
space = Discrete(element, shape, dtype)
sample = space.sample(key)
print(sample)
assert jnp.all(jnp.logical_not(jnp.isnan(sample)))


def test_continuous():
key = jax.random.PRNGKey(42)
shapes = ((), (0,), (0, 0), (1, 2), (5, 5))
min_max = [(0.0, 1.0), (0.0, 1), (0, 1), (1.0, -1.0), (MIN_INT, MAX_INT),]
for shape in shapes:
for minimum, maximum in min_max:
space = Continuous(shape=shape, minimum=jnp.asarray(minimum), maximum=jnp.asarray(maximum))
sample = space.sample(key)
print(sample)
assert jnp.all(jnp.logical_not(jnp.isnan(sample)))


if __name__ == '__main__':
test_discrete()
test_continuous()

0 comments on commit b838db7

Please sign in to comment.