Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] MJX environment prototype (WIP) #834

Draft
wants to merge 42 commits into
base: main
Choose a base branch
from
Draft
Changes from 30 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
cf153f6
Add Hopper and Walker2D models for v5
Kallinteris-Andreas May 2, 2023
bc92449
Merge branch 'Farama-Foundation:main' into main
Kallinteris-Andreas May 9, 2023
0cbdd72
Delete hopper_v5.xml
Kallinteris-Andreas May 9, 2023
db3734e
Delete walker2d_v5.xml
Kallinteris-Andreas May 9, 2023
a2d2e64
General MuJoCo Env Documention Cleanup
Kallinteris-Andreas May 9, 2023
f58bb5e
typofix
Kallinteris-Andreas May 9, 2023
7a4bc32
typo fix
Kallinteris-Andreas May 9, 2023
2418631
update following @pseudo-rnd-thoughts reviews
Kallinteris-Andreas May 9, 2023
3b9080b
Merge branch 'Farama-Foundation:main' into main
Kallinteris-Andreas Jun 5, 2023
77bcb8b
Merge branch 'Farama-Foundation:main' into main
Kallinteris-Andreas Jun 16, 2023
7639d18
refactor `tests/env/test_mojoco.py` ->
Kallinteris-Andreas Jun 16, 2023
8eb1b11
Merge branch 'Farama-Foundation:main' into main
Kallinteris-Andreas Jun 27, 2023
61d0848
Update setup.py
Kallinteris-Andreas Oct 23, 2023
5831a19
do nothing
Kallinteris-Andreas Oct 23, 2023
803dc49
Merge branch 'Farama-Foundation:main' into main
Kallinteris-Andreas Nov 3, 2023
d99cc5d
[MuJoCo] add action space figures
Kallinteris-Andreas Nov 3, 2023
f788bb3
Merge branch 'Farama-Foundation:main' into main
Kallinteris-Andreas Nov 10, 2023
450b471
Merge branch 'Farama-Foundation:main' into mjx
Kallinteris-Andreas Nov 30, 2023
14fb4d8
replace `flat.copy()` with `flatten()`
Kallinteris-Andreas Dec 5, 2023
1583839
Merge branch 'Farama-Foundation:main' into mjx
Kallinteris-Andreas Dec 5, 2023
47a7059
add `MuJoCo.test_model_sensors`
Kallinteris-Andreas Dec 6, 2023
9dc31e2
`test_model_sensors` remove check for standup `v3`
Kallinteris-Andreas Dec 6, 2023
bededa3
factorize `_get_rew()` out of `step`
Kallinteris-Andreas Dec 6, 2023
999d888
some cleanup
Kallinteris-Andreas Dec 6, 2023
0f59baa
support `python==3.8`
Kallinteris-Andreas Dec 6, 2023
76f5e17
fix for real this time
Kallinteris-Andreas Dec 6, 2023
724e47f
`black`
Kallinteris-Andreas Dec 6, 2023
32c1cb8
add prototype
Kallinteris-Andreas Dec 10, 2023
e1772bc
Merge branch 'Farama-Foundation:main' into mjx
Kallinteris-Andreas Dec 10, 2023
30cc231
cleanup
Kallinteris-Andreas Dec 15, 2023
925fcdc
update mjx envs
Kallinteris-Andreas Feb 1, 2024
b7f8806
Merge branch 'Farama-Foundation:main' into mjx
Kallinteris-Andreas Feb 1, 2024
08299e7
huge update
Kallinteris-Andreas Feb 5, 2024
be527c4
Merge branch 'Farama-Foundation:main' into mjx
Kallinteris-Andreas Feb 5, 2024
04ed837
`pre-commit`
Kallinteris-Andreas Feb 5, 2024
72d87ba
Merge branch 'Farama-Foundation:main' into mjx
Kallinteris-Andreas Feb 15, 2024
696b0a0
update
Kallinteris-Andreas Feb 15, 2024
a7c614a
`pre-commit`
Kallinteris-Andreas Feb 15, 2024
3e56f40
fix reacher
Kallinteris-Andreas Feb 22, 2024
9fa9225
Merge branch 'Farama-Foundation:main' into mjx
Kallinteris-Andreas Oct 12, 2024
a5b9bba
`pre-commit`
Kallinteris-Andreas Oct 12, 2024
4e28d8a
update func_env
Kallinteris-Andreas Oct 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
315 changes: 315 additions & 0 deletions gymnasium/envs/mujoco/f.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
from os import path

import numpy as np

import gymnasium
from gymnasium.envs.mujoco import MujocoRenderer


try:
import jax
import mujoco
from jax import numpy as jnp
from mujoco import mjx
except ImportError as e:
MJX_IMPORT_ERROR = e
else:
MJX_IMPORT_ERROR = None

DEFAULT_CAMERA_CONFIG = { # TODO reuse the one from v5
"distance": 4.0,
}


class MJXEnv(
gymnasium.functional.FuncEnv[
mjx._src.types.Data, jnp.ndarray, jnp.ndarray, jnp.ndarray, bool, MujocoRenderer
]
):
"""The Base for MJX Environments"""

def __init__(self, model_path, frame_skip):
    """Load the MuJoCo model, transfer it to MJX, and build the action space.

    Args:
        model_path: Location of the model XML file. A path starting with
            `.` or `/` is used as given, a path starting with `~` is expanded
            with `path.expanduser`, and anything else is looked up in the
            `assets` directory next to this module.
        frame_skip: Number of `mjx.step` calls performed per `transition`.

    Raises:
        gymnasium.error.DependencyNotInstalled: If the MJX imports failed.
        OSError: If the resolved model file does not exist.
    """
    if MJX_IMPORT_ERROR is not None:
        raise gymnasium.error.DependencyNotInstalled(
            f"{MJX_IMPORT_ERROR}. "
            "(HINT: you need to install mujoco, run `pip install gymnasium[mjx]`.)"  # TODO actually create gymnasium[mjx]
        )

    # NOTE can not be JITted because `Box` does not support jax.numpy
    # TODO cleanup the path resolution
    if model_path.startswith((".", "/")):
        resolved_path = model_path
    elif model_path.startswith("~"):
        resolved_path = path.expanduser(model_path)
    else:
        resolved_path = path.join(path.dirname(__file__), "assets", model_path)
    self.fullpath = resolved_path
    if not path.exists(self.fullpath):
        raise OSError(f"File {self.fullpath} does not exist")

    self.frame_skip = frame_skip

    # TODO? do not store and replace with mjx.get_model with mjx==3.1
    self.model = mujoco.MjModel.from_xml_path(self.fullpath)
    # NOTE too much state?
    # alternative state implementations
    # 1. functional_state = (mjx_data, mjx_model), least internal state in MJXenv, most state in functional_state
    # 2. functional_state = [qpos,qvel], most internal state in MJXenv, least state in functional_state
    self.mjx_model = mjx.device_put(self.model)

    # set action space
    ctrl_low, ctrl_high = self.mjx_model.actuator_ctrlrange.T
    # TODO change bounds and types when and if `Box` supports JAX natively
    self.action_space = gymnasium.spaces.Box(
        low=np.array(ctrl_low),
        high=np.array(ctrl_high),
        dtype=np.float32,
    )
    # self.action_space = gymnasium.spaces.Box(low=low_action_bound, high=high_action_bound, dtype=low_action_bound.dtype)
    # observation_space: gymnasium.spaces.Box  # set by the sub-class

def initial(self, rng: jax.random.PRNGKey) -> mjx._src.types.Data:
    """Build the initial simulator state.

    A fresh MJX data object is created, its `qpos`/`qvel` are replaced by the
    values returned from `_gen_init_state(rng)`, and `mjx.forward` is run on
    the result.
    """
    # TODO? find a more performant alternative that does not allocate?
    blank_data = mjx.make_data(self.model)
    init_qpos, init_qvel = self._gen_init_state(rng)
    seeded_data = blank_data.replace(qpos=init_qpos, qvel=init_qvel)
    return mjx.forward(self.mjx_model, seeded_data)

def transition(
    self, state: mjx._src.types.Data, action: jnp.ndarray, rng=None
) -> mjx._src.types.Data:
    """Step through the simulator using `action` for `self.dt`.

    `action` is written to the `ctrl` field of the state and `mjx.step` is
    applied `self.frame_skip` times. `rng` is accepted but not used here.
    """

    def _single_step(_, data):
        # The loop index is irrelevant; every iteration is one MJX step.
        return mjx.step(self.mjx_model, data)

    controlled = state.replace(ctrl=action)
    # TODO fix sensors with MJX>=3.1
    return jax.lax.fori_loop(0, self.frame_skip, _single_step, controlled)

def reward(
    self,
    state: mjx._src.types.Data,
    action: jnp.ndarray,
    next_state: mjx._src.types.Data,
) -> jnp.ndarray:
    """Return the reward, i.e. the first element of `_get_reward(...)`."""
    reward_and_info = self._get_reward(state, action, next_state)
    return reward_and_info[0]

def transition_info(
    self,
    state: mjx._src.types.Data,
    action: jnp.ndarray,
    next_state: mjx._src.types.Data,
) -> dict:
    """Return the transition info, i.e. the second element of `_get_reward(...)`."""
    reward_and_info = self._get_reward(state, action, next_state)
    return reward_and_info[1]

def render_image(
    self, state: mjx._src.types.Data, render_state: MujocoRenderer
) -> tuple[MujocoRenderer, np.ndarray | None]:
    """Render one frame for `state` with the given renderer.

    The MJX state is copied into a newly created `mujoco.MjData`,
    `mujoco.mj_forward` is run on it, and the renderer is pointed at that
    data before rendering with the mode and camera stored by `render_init`.

    Returns: `(render_state, frame)`
    """
    # NOTE function can not be jitted
    renderer = render_state

    host_data = mujoco.MjData(self.model)
    mjx.device_get_into(host_data, state)  # TODO use get_data instead once mjx==3.1
    mujoco.mj_forward(self.model, host_data)
    renderer.data = host_data

    image = renderer.render(self.render_mode, self.camera_id, self.camera_name)
    return renderer, image

def render_init(
    self,
    default_camera_config: dict[str, float] | None = None,
    camera_id: int | None = None,
    camera_name: str | None = None,
    max_geom=1000,
    width=480,
    height=480,
    render_mode="rgb_array",
) -> MujocoRenderer:
    """Create the `MujocoRenderer` used as the render state.

    Args:
        default_camera_config: Camera settings forwarded to `MujocoRenderer`.
            `None` (the default) is treated as an empty configuration; a
            fresh dict is created per call so that no dict is shared between
            calls.
        camera_id: Stored on `self` and passed to `render` by `render_image`.
        camera_name: Stored on `self` and passed to `render` by `render_image`.
        max_geom: Forwarded to `MujocoRenderer`.
        width: Forwarded to `MujocoRenderer`.
        height: Forwarded to `MujocoRenderer`.
        render_mode: Stored on `self` and passed to `render` by `render_image`.

    Returns:
        The newly created `MujocoRenderer`.
    """
    if default_camera_config is None:
        default_camera_config = {}

    # TODO storing too much state? it should probably be moved internal to MujocoRenderer
    self.render_mode = render_mode
    self.camera_id = camera_id
    self.camera_name = camera_name

    # NOTE `MujocoRenderer` is not JAX native, so the render functions will not
    # be JITable for a while (if ever), because of the structure of the
    # deepmind/mujoco project.
    # The second argument is passed as `None`; `render_image` assigns
    # `renderer.data` before every render call.
    return MujocoRenderer(
        self.model,
        None,
        default_camera_config,
        width,
        height,
        max_geom,
    )

def render_close(self, render_state: MujocoRenderer) -> None:
    """Close the renderer held in `render_state`; a `None` render state is ignored."""
    if render_state is None:
        return
    render_state.close()

@property
def dt(self) -> float:
    """Product of the model's `opt.timestep` and `frame_skip`."""
    timestep = self.mjx_model.opt.timestep
    return timestep * self.frame_skip

def _gen_init_state(self, rng) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Generates the initial joint state from `rng`; must be implemented by the sub-class.

    Returns: `(qpos, qvel)`
    """
    # NOTE alternatives
    # 1. return the state in a single vector
    # 2. return it as a dictionary keyed by "qpos" & "qvel"
    raise NotImplementedError

def _get_reward(
    self,
    state: mjx._src.types.Data,
    action: jnp.ndarray,
    next_state: mjx._src.types.Data,
) -> tuple[jnp.ndarray, dict]:
    """
    Generates `reward` and `transition_info`; must be implemented by the sub-class.

    `reward()` and `transition_info()` each call this method and keep one
    element of the result, so the computation is written once and the JIT is
    relied upon to optimize away the duplicated work.
    NOTE(review): the original wording was "the JIT's SEE" - presumably CSE
    (common subexpression elimination) was meant; confirm.

    Returns: `(reward, transition_info)`
    """
    raise NotImplementedError

def observation(self, state: mjx._src.types.Data) -> jnp.ndarray:
    """Return the observation for `state`; must be implemented by the sub-class."""
    raise NotImplementedError

def terminal(self, state: mjx._src.types.Data) -> bool:
    """Return whether `state` is terminal; must be implemented by the sub-class."""
    raise NotImplementedError

def state_info(self, state: mjx._src.types.Data) -> dict:
    """Return the info dictionary for `state`; must be implemented by the sub-class."""
    raise NotImplementedError


# TODO in which file to place this class? in `half_cheetah_v5.py`?
Kallinteris-Andreas marked this conversation as resolved.
Show resolved Hide resolved
class HalfCheetahMJXEnv(MJXEnv, gymnasium.utils.EzPickle):
def __init__(
    self,
    xml_file: str = "half_cheetah.xml",
    frame_skip: int = 5,
    forward_reward_weight: float = 1.0,
    ctrl_cost_weight: float = 0.1,
    reset_noise_scale: float = 0.1,
    exclude_current_positions_from_observation: bool = True,
    **kwargs,
):
    """Configure the HalfCheetah environment.

    Args:
        xml_file: Model file handed to `MJXEnv.__init__` as `model_path`.
        frame_skip: Number of simulator steps per `transition`.
        forward_reward_weight: Multiplier of the x-velocity in the reward.
        ctrl_cost_weight: Multiplier of the squared-action cost in the reward.
        reset_noise_scale: Scale of the noise used by `_gen_init_state`.
        exclude_current_positions_from_observation: When true, the first
            `qpos` element is dropped from the observation.
        **kwargs: Forwarded to `EzPickle.__init__` and `MJXEnv.__init__`.
    """
    gymnasium.utils.EzPickle.__init__(
        self,
        xml_file,
        frame_skip,
        forward_reward_weight,
        ctrl_cost_weight,
        reset_noise_scale,
        exclude_current_positions_from_observation,
        **kwargs,
    )

    # Reward and reset parameters read by the functional methods.
    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )

    MJXEnv.__init__(self, model_path=xml_file, frame_skip=frame_skip, **kwargs)

    n_qpos = self.mjx_model.nq
    n_qvel = self.mjx_model.nv
    obs_size = n_qpos + n_qvel - exclude_current_positions_from_observation

    # TODO use jnp when and if `Box` supports jax natively (JAX-compatible
    # spaces would also be needed to sample inside a JIT-compiled training function)
    self.observation_space = gymnasium.spaces.Box(
        low=-np.inf, high=np.inf, shape=(obs_size,), dtype=np.float32
    )

    self.observation_structure = {
        "skipped_qpos": 1 * exclude_current_positions_from_observation,
        "qpos": n_qpos - 1 * exclude_current_positions_from_observation,
        "qvel": n_qvel,
    }

def _gen_init_state(self, rng) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Sample the initial joint positions and velocities.

    `qpos` is the model's `qpos0` plus uniform noise in
    `[-reset_noise_scale, reset_noise_scale)`; `qvel` is standard normal
    noise multiplied by `reset_noise_scale`.

    Args:
        rng: JAX PRNG key; it is split so that each draw gets its own sub-key.

    Returns: `(qpos, qvel)`
    """
    # A JAX key must only be consumed once: drawing both samples from `rng`
    # would derive the position and velocity noise from the same random bits.
    qpos_key, qvel_key = jax.random.split(rng)

    noise_low = -self._reset_noise_scale
    noise_high = self._reset_noise_scale

    qpos = self.mjx_model.qpos0 + jax.random.uniform(
        key=qpos_key,
        minval=noise_low,
        maxval=noise_high,
        shape=(self.mjx_model.nq,),
    )
    qvel = self._reset_noise_scale * jax.random.normal(
        key=qvel_key, shape=(self.mjx_model.nv,)
    )

    return qpos, qvel

def observation(self, state: mjx._src.types.Data) -> jnp.ndarray:
    """Concatenate the flattened `qpos` and `qvel` of `state`.

    The first `qpos` element is left out when
    `exclude_current_positions_from_observation` was set.
    """
    qpos_flat = state.qpos.flatten()
    qvel_flat = state.qvel.flatten()

    first_kept = 1 if self._exclude_current_positions_from_observation else 0
    return jnp.concatenate((qpos_flat[first_kept:], qvel_flat))

def _get_reward(
    self,
    state: mjx._src.types.Data,
    action: jnp.ndarray,
    next_state: mjx._src.types.Data,
) -> tuple[jnp.ndarray, dict]:
    """Compute the reward and its components for one transition.

    The x-velocity is the change of `qpos[0]` between `state` and
    `next_state` divided by `self.dt`. The reward is the weighted x-velocity
    minus the weighted sum of squared actions.

    Returns: `(reward, reward_info)`
    """
    x_before = state.qpos[0]
    x_after = next_state.qpos[0]
    x_velocity = (x_after - x_before) / self.dt

    forward_reward = self._forward_reward_weight * x_velocity
    ctrl_cost = self._ctrl_cost_weight * jnp.sum(jnp.square(action))

    reward_info = {
        "reward_forward": forward_reward,
        "reward_ctrl": -ctrl_cost,
        "x_velocity": x_velocity,
    }
    return forward_reward - ctrl_cost, reward_info

def terminal(self, state: mjx._src.types.Data) -> bool:
    """Always return `False`; `state` is not inspected."""
    return False
    # NOTE or: return jnp.array(False)

def state_info(self, state: mjx._src.types.Data) -> dict:
    """Return an info dictionary holding `qpos[0]` of `state` as `x_position`."""
    return {
        "x_position": state.qpos[0],
    }

def render_init(
    self, default_camera_config: dict[str, float] = DEFAULT_CAMERA_CONFIG, **kwargs
) -> MujocoRenderer:
    """Create the renderer, defaulting the camera to `DEFAULT_CAMERA_CONFIG`.

    All other keyword arguments are forwarded to the parent `render_init`.
    """
    # `kwargs` is a fresh dict per call and cannot already hold this key,
    # because `default_camera_config` is a named parameter.
    kwargs["default_camera_config"] = default_camera_config
    return super().render_init(**kwargs)


# TODO add vector environment
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These should be super simple. It's essentially vmap step and reset

# TODO consider requirement of `metaworld` & `gymnasium_robotics.RobotEnv`
Loading