Skip to content

Commit

Permalink
Fixed MjSpec introspection with visual.rgba and visual.headlight. Add…
Browse files Browse the repository at this point in the history
…ed access to MjSpec.visual.global_.
  • Loading branch information
AaronYoung5 committed Dec 14, 2024
1 parent b26d6f0 commit ba25a69
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 1 deletion.
7 changes: 6 additions & 1 deletion python/mujoco/codegen/generate_spec_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,13 @@ def _binding_code(field: ast_nodes.StructFieldDecl, key: str) -> str:
if isinstance(field.type, ast_nodes.ValueType):
return _value_binding_code(field.type, key, field.name)
elif isinstance(field.type, ast_nodes.AnonymousStructDecl):
code = ""
if field.name in ['headlight', 'rgba']:
for subfield in field.type.fields:
code += _binding_code(subfield, 'mjVisual'+field.name.title())
field.type = ast_nodes.ValueType(name='mjVisual'+field.name.title())
return _value_binding_code(field.type, key, field.name)
code += _value_binding_code(field.type, key, field.name)
return code
elif isinstance(field.type, ast_nodes.PointerType):
return _ptr_binding_code(field.type, key, field.name)
elif isinstance(field.type, ast_nodes.ArrayType):
Expand Down
9 changes: 9 additions & 0 deletions python/mujoco/specs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,8 @@ PYBIND11_MODULE(_specs, m) {
py::class_<raw::MjOption> mjOption(m, "MjOption");
py::class_<raw::MjStatistic> mjStatistic(m, "MjStatistic");
py::class_<raw::MjVisual> mjVisual(m, "MjVisual");
py::class_<raw::MjVisualHeadlight> mjVisualHeadlight(m, "MjVisualHeadlight");
py::class_<raw::MjVisualRgba> mjVisualRgba(m, "MjVisualRgba");
py::class_<raw::MjsCompiler> mjsCompiler(m, "MjsCompiler");
DefineArray<char>(m, "MjCharVec");
DefineArray<std::string>(m, "MjStringVec");
Expand Down Expand Up @@ -979,6 +981,13 @@ PYBIND11_MODULE(_specs, m) {
});
mjsPlugin.def("delete",
[](raw::MjsPlugin& self) { mjs_delete(self.element); });
// ============================= MJVISUAL ====================================
mjVisual.def_property(
"global_",
[](raw::MjVisual& self) -> raw::MjVisualGlobal& { return self.global; },
[](raw::MjVisual& self, raw::MjVisualGlobal& value) {
self.global = value;
});

#include "specs.cc.inc"
} // PYBIND11_MODULE // NOLINT
Expand Down
9 changes: 9 additions & 0 deletions python/mujoco/specs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,22 +848,31 @@ def test_access_option_stat_visual(self):
<statistic meansize="0.05"/>
<visual>
<quality shadowsize="4096"/>
<headlight active="0"/>
<rgba camera="0 0 0 0"/>
</visual>
</mujoco>
""")
self.assertEqual(spec.option.timestep, 0.001)
self.assertEqual(spec.stat.meansize, 0.05)
self.assertEqual(spec.visual.quality.shadowsize, 4096)
self.assertEqual(spec.visual.headlight.active, 0)
self.assertEqual(spec.visual.global_, getattr(spec.visual, 'global'))
np.testing.assert_array_equal(spec.visual.rgba.camera, [0, 0, 0, 0])

spec.option.timestep = 0.002
spec.stat.meansize = 0.06
spec.visual.quality.shadowsize = 8192
spec.visual.headlight.active = 1
spec.visual.rgba.camera = [1, 1, 1, 1]

model = spec.compile()

self.assertEqual(model.opt.timestep, 0.002)
self.assertEqual(model.stat.meansize, 0.06)
self.assertEqual(model.vis.quality.shadowsize, 8192)
self.assertEqual(model.vis.headlight.active, 1)
np.testing.assert_array_equal(model.vis.rgba.camera, [1, 1, 1, 1])

def test_assign_list_element(self):
spec = mujoco.MjSpec()
Expand Down

0 comments on commit ba25a69

Please sign in to comment.