diff --git a/navix/_version.py b/navix/_version.py index a4beb31..397649b 100644 --- a/navix/_version.py +++ b/navix/_version.py @@ -18,5 +18,5 @@ # under the License. -__version__ = "0.3.8" +__version__ = "0.3.9" __version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit()) diff --git a/navix/environments/environment.py b/navix/environments/environment.py index 225965d..eca7f54 100644 --- a/navix/environments/environment.py +++ b/navix/environments/environment.py @@ -20,7 +20,6 @@ from __future__ import annotations -import sys import abc from enum import IntEnum from typing import Any, Callable, Dict @@ -84,10 +83,10 @@ def observation_space(self) -> Space: if self.observation_fn == observations.none: return Continuous(shape=()) elif self.observation_fn == observations.categorical: - return Discrete(sys.maxsize, shape=(self.height, self.width)) + return Discrete(shape=(self.height, self.width)) elif self.observation_fn == observations.categorical_first_person: radius = observations.RADIUS - return Discrete(sys.maxsize, shape=(radius + 1, radius * 2 + 1)) + return Discrete(shape=(radius + 1, radius * 2 + 1)) elif self.observation_fn == observations.rgb: return Discrete( 256, diff --git a/navix/spaces.py b/navix/spaces.py index 66e0011..748908d 100644 --- a/navix/spaces.py +++ b/navix/spaces.py @@ -14,7 +14,6 @@ from __future__ import annotations -from typing import Callable import jax import jax.numpy as jnp @@ -25,13 +24,15 @@ from jax.core import ShapedArray, Shape -POS_INF = jnp.asarray(1e16) -NEG_INF = jnp.asarray(-1e16) +MIN_INT = jax.numpy.iinfo(jnp.int16).min +MAX_INT = jax.numpy.iinfo(jnp.int16).max +MIN_INT_ARR = jnp.asarray(MIN_INT) +MAX_INT_ARR = jnp.asarray(MAX_INT) class Space(ShapedArray): - minimum: Array = NEG_INF - maximum: Array = POS_INF + minimum: Array = MIN_INT_ARR + maximum: Array = MAX_INT_ARR def __repr__(self): return "{}, min={}, max={})".format( @@ -43,19 +44,21 @@ def sample(self, key: KeyArray) -> Array: class Discrete(Space): - def __init__(self, n_elements: int, shape: Shape = (), dtype=jnp.int32): + def __init__(self, n_elements: int = MAX_INT, shape: Shape = (), dtype=jnp.int32): super().__init__(shape, dtype) self.minimum = jnp.asarray(0) self.maximum = jnp.asarray(n_elements - 1) def sample(self, key: KeyArray) -> Array: - return jax.random.randint( - key, self.shape, self.minimum, self.maximum, self.dtype + item = jax.random.randint( + key, self.shape, self.minimum, self.maximum ) + # randint cannot draw jnp.uint, so we cast it later + return jnp.asarray(item, dtype=self.dtype) class Continuous(Space): - def __init__(self, shape: Shape = (), minimum=NEG_INF, maximum=POS_INF): + def __init__(self, shape: Shape = (), minimum=MIN_INT_ARR, maximum=MAX_INT_ARR): super().__init__(shape, jnp.float32) self.minimum = minimum self.maximum = maximum diff --git a/tests/test_spaces.py b/tests/test_spaces.py new file mode 100644 index 0000000..3901259 --- /dev/null +++ b/tests/test_spaces.py @@ -0,0 +1,34 @@ +import jax +import jax.numpy as jnp +from navix.spaces import Space, Continuous, Discrete, MAX_INT, MIN_INT + + +def test_discrete(): + key = jax.random.PRNGKey(42) + elements = (5, 0, MAX_INT, MIN_INT) + shapes = ((), (0,), (0, 0), (1, 2), (5, 5)) + dtypes = (jnp.int8, jnp.int16, jnp.int32) + for element in elements: + for shape in shapes: + for dtype in dtypes: + space = Discrete(element, shape, dtype) + sample = space.sample(key) + print(sample) + assert jnp.all(jnp.logical_not(jnp.isnan(sample))) + + +def test_continuous(): + key = jax.random.PRNGKey(42) + shapes = ((), (0,), (0, 0), (1, 2), (5, 5)) + min_max = [(0.0, 1.0), (0.0, 1), (0, 1), (1.0, -1.0), (MIN_INT, MAX_INT),] + for shape in shapes: + for minimum, maximum in min_max: + space = Continuous(shape=shape, minimum=jnp.asarray(minimum), maximum=jnp.asarray(maximum)) + sample = space.sample(key) + print(sample) + assert jnp.all(jnp.logical_not(jnp.isnan(sample))) + + +if __name__ == '__main__': + test_discrete() + test_continuous() \ No newline at end of file