From 69114bc39aed1ff0b56422621b658168c843717e Mon Sep 17 00:00:00 2001 From: Matthew Chang Date: Thu, 30 May 2024 22:26:20 +0000 Subject: [PATCH] Add new api for getting and setting object states which supports custom serialization --- .../habitat_simulator/object_state_machine.py | 109 +++++++++++++++++- test/test_object_state_machine.py | 39 ++++++- 2 files changed, 145 insertions(+), 3 deletions(-) diff --git a/habitat-lab/habitat/sims/habitat_simulator/object_state_machine.py b/habitat-lab/habitat/sims/habitat_simulator/object_state_machine.py index 730d1dc9d3..a937cb0e27 100644 --- a/habitat-lab/habitat/sims/habitat_simulator/object_state_machine.py +++ b/habitat-lab/habitat/sims/habitat_simulator/object_state_machine.py @@ -156,6 +156,61 @@ def draw_state( :param camera_transform: The Matrix4 camera transform. """ + def get_state_of_obj( + self, obj: Union[ManagedArticulatedObject, ManagedRigidObject] + ) -> Any: + """ + Retrieves the state of a given object for this state. + + :param obj: The object to retrieve the state for. + :return: The state of the object for this state, or None if the object does not afford this state. + :rtype: Any + """ + if not self.is_affordance_of_obj(obj): + return None + + if "object_states" in obj.user_attributes.get_subconfig_keys(): + obj_states_config = obj.user_attributes.get_subconfig( + "object_states" + ) + if obj_states_config.has_value(self.name): + return self.deserialize_value(obj_states_config.get(self.name)) + return self.default_value() + + def set_state_of_obj( + self, + obj: Union[ManagedArticulatedObject, ManagedRigidObject], + value: Any, + ) -> None: + """ + Sets the state of a given object for this state spec. + + :param obj: The object to set the state for. + :param value: The value of the state to set. + """ + user_attr = obj.user_attributes + obj_state_config = user_attr.get_subconfig("object_states") + obj_state_config.set(self.name, self.serialize_value(value)) + user_attr.save_subconfig("object_states", obj_state_config) + + def serialize_value(self, value: Any) -> Any: + """ + Serializes the value of the state to be stored in the object states configuration. + Should be overridden by subclasses for complex/non-primitive state types. + :param value: The value of the state to set. + """ + return value + + def deserialize_value(self, config_value: Any) -> Any: + """ + Deserializes the value of the state loaded from the user_attributes config + Should be overridden by subclasses for complex/non-primitive state types. + + :param config_value: The value loaded from the user_attributes config. + :return: The deserialize value of the state. + """ + return config_value + class BooleanObjectState(ObjectStateSpec): """ @@ -214,6 +269,9 @@ def toggle( set_state_of_obj(obj, self.name, new_state) return new_state + def deserialize_value(self, config_value: Any) -> bool: + return bool(config_value) + class ObjectIsClean(BooleanObjectState): """ @@ -303,6 +361,55 @@ def update_states(self, sim: habitat_sim.Simulator, dt: float) -> None: for state in states: state.update_state(sim, obj, dt) + def _get_state_spec_from_name(self, state_name: str) -> ObjectStateSpec: + """ + Retrieves the state specification for a given state name. + + :param state_name: The name of the state. + :return: The state specification for the given state name. + """ + for state in self.active_states: + if state.name == state_name: + return state + raise ValueError(f"State {state_name} not found in active states.") + + def get_state_of_obj( + self, + obj: Union[ManagedArticulatedObject, ManagedRigidObject], + state_name: str, + ) -> Any: + """ + Retrieves the state of a given object for a specific state name. + + :param obj: The object to retrieve the state for. + :param state_name: The name of the state. + :return: The state of the object for the given state name, or None if the object does not afford this state. + :rtype: Any + """ + object_state_spec = self._get_state_spec_from_name(state_name) + if object_state_spec in self.objects_with_states[obj.handle]: + return object_state_spec.get_state_of_obj(obj) + else: + # Return none if the object does not afford this state + return None + + def set_state_of_obj( + self, + obj: Union[ManagedArticulatedObject, ManagedRigidObject], + state_name: str, + value: Any, + ) -> None: + """ + Sets the state of a given object for a specific state name. Does nothing if the object does not afford the state. + + :param obj: The object to set the state for. + :param state_name: The name of the state. + :param value: The value of the state to set. + """ + object_state_spec = self._get_state_spec_from_name(state_name) + if object_state_spec in self.objects_with_states[obj.handle]: + object_state_spec.set_state_of_obj(obj, value) + def get_snapshot_dict( self, sim: habitat_sim.Simulator ) -> Dict[str, Dict[str, Any]]: @@ -329,7 +436,7 @@ def get_snapshot_dict( for object_handle, states in self.objects_with_states.items(): obj = sutils.get_obj_from_handle(sim, object_handle) for state in states: - obj_state = get_state_of_obj(obj, state.name) + obj_state = self.get_state_of_obj(obj, state.name) snapshot[state.name][object_handle] = ( obj_state if obj_state is not None diff --git a/test/test_object_state_machine.py b/test/test_object_state_machine.py index 55975ccb30..7287d7d9b9 100644 --- a/test/test_object_state_machine.py +++ b/test/test_object_state_machine.py @@ -4,13 +4,15 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Union +import json +from typing import Any, Tuple, Union import magnum as mn from habitat.sims.habitat_simulator.object_state_machine import ( BooleanObjectState, ObjectStateMachine, + ObjectStateSpec, get_state_of_obj, set_state_of_obj, ) @@ -69,6 +71,23 @@ def update_state( set_state_of_obj(obj, self.name, False) +# TupleObjectState is a contrived example of a state that is a tuple +class TupleObjectState(ObjectStateSpec): + def __init__(self): + super().__init__() + self.name = "TupleObjectState" + self.accepted_semantic_classes = ["test_class"] + + def default_value(self) -> Tuple: + return (1, "test") + + def serialize_value(self, value: Tuple) -> str: + return json.dumps(value) + + def deserialize_value(self, config_value: Any) -> Any: + return tuple(json.loads(config_value)) + + def test_object_state_machine(): """ Test initializing and assigning a state to the state machine. @@ -92,7 +111,9 @@ def test_object_state_machine(): assert get_state_of_obj(new_obj, "semantic_class") == "test_class" # initialize the ObjectStateMachine - osm = ObjectStateMachine(active_states=[TestObjectState()]) + osm = ObjectStateMachine( + active_states=[TestObjectState(), TupleObjectState()] + ) osm.initialize_object_state_map(sim) # now the cube should be registered for TestObjectState because it has the correct semantic_class @@ -102,6 +123,20 @@ def test_object_state_machine(): osm.objects_with_states[new_obj.handle][0], TestObjectState ) + # The default value of the state should be True + assert osm.get_state_of_obj(new_obj, "TestState") == True + # Set the state to False + osm.set_state_of_obj(new_obj, "TestState", False) + # The return type should still be boolean and not integer + assert osm.get_state_of_obj(new_obj, "TestState") is False + # Set back to default value + osm.set_state_of_obj(new_obj, "TestState", True) + + # Test the TupleObjectState can read and write tuples + assert osm.get_state_of_obj(new_obj, "TupleObjectState") == (1, "test") + osm.set_state_of_obj(new_obj, "TupleObjectState", (8, "val")) + assert osm.get_state_of_obj(new_obj, "TupleObjectState") == (8, "val") + state_report_dict = osm.get_snapshot_dict(sim) assert "TestState" in state_report_dict assert new_obj.handle in state_report_dict["TestState"]